From 50a5f9cdc20c51a1651400f34351c6da9f44e218 Mon Sep 17 00:00:00 2001 From: Gian Merlino Date: Tue, 1 Oct 2024 14:38:55 -0700 Subject: [PATCH] MSQ profile for Brokers and Historicals. (#17140) This patch adds a profile of MSQ named "Dart" that runs on Brokers and Historicals, and which is compatible with the standard SQL query API. For a higher-level description, and notes on future work, refer to #17139. This patch contains the following changes, grouped into packages. Controller (org.apache.druid.msq.dart.controller): The controller runs on Brokers. Main classes are: - DartSqlResource, which serves /druid/v2/sql/dart/. - DartSqlEngine and DartQueryMaker, the entry points from SQL that actually run the MSQ controller code. - DartControllerContext, which configures the MSQ controller. - DartMessageRelays, which sets up relays (see "message relays" below) to read messages from workers' DartControllerClients. - DartTableInputSpecSlicer, which assigns work based on a TimelineServerView. Worker (org.apache.druid.msq.dart.worker): The worker runs on Historicals. Main classes are: - DartWorkerResource, which supplies the regular MSQ WorkerResource, plus Dart-specific APIs. - DartWorkerRunner, which runs MSQ worker code. - DartWorkerContext, which configures the MSQ worker. - DartProcessingBuffersProvider, which provides processing buffers from sliced-up merge buffers. - DartDataSegmentProvider, which provides segments from the Historical's local cache. Message relays (org.apache.druid.messages): To avoid the need for Historicals to contact Brokers during a query, which would create opportunities for queries to get stuck, all connections are opened from Broker to Historical. This is made possible by a message relay system, where the relay server (worker) has an outbox of messages. The relay client (controller) connects to the outbox and retrieves messages. 
Code for this system lives in the "server" package to keep it separate from the MSQ extension and make it easier to maintain. The worker-to-controller ControllerClient is implemented using message relays. Other changes: - Controller: Added the method "hasWorker". Used by the ControllerMessageListener to notify the appropriate controllers when a worker fails. - WorkerResource: No longer tries to respond more than once in the "httpGetChannelData" API. This comes up when a response due to a resolved future becomes ready at about the same time as a timeout occurs. - MSQTaskQueryMaker: Refactor to separate out some useful functions for reuse in DartQueryMaker. - SqlEngine: Add "queryContext" to "resultTypeForSelect" and "resultTypeForInsert". This allows the DartSqlEngine to modify result format based on whether a "fullReport" context parameter is set. - LimitedOutputStream: New utility class. Used when in "fullReport" mode. - TimelineServerView: Add getDruidServerMetadata as a performance optimization. - CliHistorical: Add SegmentWrangler, so it can query inline data, lookups, etc. - ServiceLocation: Add "fromUri" method, relocating some code from ServiceClientImpl. - FixedServiceLocator: New locator for a fixed set of service locations. Useful for URI locations. 
--- .../java/org/apache/druid/msq/dart/Dart.java | 37 + .../dart/DartResourcePermissionMapper.java | 57 ++ .../msq/dart/controller/ControllerHolder.java | 166 ++++ .../controller/ControllerMessageListener.java | 68 ++ .../controller/DartControllerContext.java | 246 ++++++ .../DartControllerContextFactory.java | 31 + .../DartControllerContextFactoryImpl.java | 83 ++ .../controller/DartControllerRegistry.java | 72 ++ .../DartMessageRelayFactoryImpl.java | 82 ++ .../dart/controller/DartMessageRelays.java | 40 + .../controller/DartTableInputSpecSlicer.java | 292 +++++++ .../dart/controller/DartWorkerManager.java | 200 +++++ .../dart/controller/http/DartQueryInfo.java | 189 +++++ .../dart/controller/http/DartSqlResource.java | 275 +++++++ .../controller/http/GetQueriesResponse.java | 73 ++ .../messages/ControllerMessage.java | 49 ++ .../controller/messages/DoneReadingInput.java | 101 +++ .../messages/PartialKeyStatistics.java | 118 +++ .../controller/messages/ResultsComplete.java | 118 +++ .../dart/controller/messages/WorkerError.java | 96 +++ .../controller/messages/WorkerWarning.java | 96 +++ .../dart/controller/sql/DartQueryMaker.java | 484 +++++++++++ .../dart/controller/sql/DartSqlClient.java | 42 + .../controller/sql/DartSqlClientFactory.java | 30 + .../sql/DartSqlClientFactoryImpl.java | 64 ++ .../controller/sql/DartSqlClientImpl.java | 57 ++ .../dart/controller/sql/DartSqlClients.java | 118 +++ .../dart/controller/sql/DartSqlEngine.java | 181 +++++ .../msq/dart/guice/DartControllerConfig.java | 44 + .../DartControllerMemoryManagementModule.java | 64 ++ .../msq/dart/guice/DartControllerModule.java | 134 ++++ .../druid/msq/dart/guice/DartModules.java | 37 + .../msq/dart/guice/DartWorkerConfig.java | 53 ++ .../DartWorkerMemoryManagementModule.java | 102 +++ .../msq/dart/guice/DartWorkerModule.java | 153 ++++ .../msq/dart/worker/DartControllerClient.java | 143 ++++ .../dart/worker/DartDataSegmentProvider.java | 111 +++ .../msq/dart/worker/DartFrameContext.java | 178 
++++ .../worker/DartProcessingBuffersProvider.java | 94 +++ .../msq/dart/worker/DartQueryableSegment.java | 89 ++ .../msq/dart/worker/DartWorkerClient.java | 210 +++++ .../msq/dart/worker/DartWorkerContext.java | 248 ++++++ .../msq/dart/worker/DartWorkerFactory.java | 33 + .../dart/worker/DartWorkerFactoryImpl.java | 142 ++++ .../dart/worker/DartWorkerRetryPolicy.java | 90 +++ .../msq/dart/worker/DartWorkerRunner.java | 349 ++++++++ .../druid/msq/dart/worker/WorkerId.java | 157 ++++ .../msq/dart/worker/http/DartWorkerInfo.java | 110 +++ .../dart/worker/http/DartWorkerResource.java | 181 +++++ .../dart/worker/http/GetWorkersResponse.java | 64 ++ .../org/apache/druid/msq/exec/Controller.java | 8 + .../apache/druid/msq/exec/ControllerImpl.java | 10 + .../org/apache/druid/msq/exec/WorkerImpl.java | 6 + .../apache/druid/msq/exec/WorkerManager.java | 7 +- .../msq/indexing/TaskReportQueryListener.java | 23 +- .../msq/indexing/error/CanceledFault.java | 10 + .../error/ColumnNameRestrictedFault.java | 9 + .../error/ColumnTypeNotSupportedFault.java | 10 + .../msq/indexing/error/MSQErrorReport.java | 26 + .../druid/msq/indexing/error/MSQFault.java | 14 + .../error/QueryNotSupportedFault.java | 10 + .../druid/msq/rpc/BaseWorkerClientImpl.java | 4 +- .../apache/druid/msq/rpc/WorkerResource.java | 15 +- .../druid/msq/sql/MSQTaskQueryMaker.java | 145 ++-- .../druid/msq/sql/MSQTaskSqlEngine.java | 22 +- .../msq/util/MSQTaskQueryMakerUtils.java | 5 +- ...rg.apache.druid.initialization.DruidModule | 4 + .../DartTableInputSpecSlicerTest.java | 488 +++++++++++ .../controller/DartWorkerManagerTest.java | 179 +++++ .../controller/http/DartQueryInfoTest.java | 32 + .../controller/http/DartSqlResourceTest.java | 757 ++++++++++++++++++ .../http/GetQueriesResponseTest.java | 61 ++ .../messages/ControllerMessageTest.java | 90 +++ .../controller/sql/DartSqlClientImplTest.java | 118 +++ .../dart/worker/DartQueryableSegmentTest.java | 32 + .../msq/dart/worker/DartWorkerRunnerTest.java | 314 
++++++++ .../druid/msq/dart/worker/WorkerIdTest.java | 102 +++ .../dart/worker/http/DartWorkerInfoTest.java | 32 + .../worker/http/GetWorkersResponseTest.java | 58 ++ .../apache/druid/msq/test/MSQTestBase.java | 3 +- .../msq/test/MSQTestControllerContext.java | 12 +- .../test/MSQTestOverlordServiceClient.java | 3 +- .../druid/msq/test/MSQTestWorkerClient.java | 6 +- .../apache/druid/common/guava/FutureBox.java | 77 ++ .../apache/druid/io/LimitedOutputStream.java | 98 +++ .../druid/common/guava/FutureBoxTest.java | 75 ++ .../druid/io/LimitedOutputStreamTest.java | 72 ++ .../apache/druid/client/BrokerServerView.java | 14 + .../druid/client/TimelineServerView.java | 18 +- .../druid/discovery/DataServerClient.java | 4 +- .../apache/druid/messages/MessageBatch.java | 112 +++ .../messages/client/MessageListener.java | 50 ++ .../druid/messages/client/MessageRelay.java | 243 ++++++ .../messages/client/MessageRelayClient.java | 43 + .../client/MessageRelayClientImpl.java | 85 ++ .../messages/client/MessageRelayFactory.java | 30 + .../druid/messages/client/MessageRelays.java | 143 ++++ .../apache/druid/messages/package-info.java | 44 + .../messages/server/MessageRelayMonitor.java | 82 ++ .../messages/server/MessageRelayResource.java | 196 +++++ .../apache/druid/messages/server/Outbox.java | 68 ++ .../druid/messages/server/OutboxImpl.java | 209 +++++ .../apache/druid/rpc/FixedServiceLocator.java | 60 ++ .../druid/rpc/FixedSetServiceLocator.java | 90 --- .../apache/druid/rpc/ServiceClientImpl.java | 46 +- .../org/apache/druid/rpc/ServiceLocation.java | 64 ++ .../druid/client/BrokerServerViewTest.java | 11 +- .../druid/messages/MessageBatchTest.java | 51 ++ .../client/MessageRelayClientImplTest.java | 92 +++ .../messages/client/MessageRelaysTest.java | 222 +++++ .../druid/messages/server/OutboxImplTest.java | 213 +++++ ...Test.java => FixedServiceLocatorTest.java} | 49 +- .../druid/rpc/ServiceClientImplTest.java | 8 - .../apache/druid/rpc/ServiceLocationTest.java | 40 + 
.../org/apache/druid/cli/CliHistorical.java | 2 + .../sql/calcite/planner/IngestHandler.java | 3 +- .../sql/calcite/planner/QueryHandler.java | 3 +- .../sql/calcite/run/NativeSqlEngine.java | 12 +- .../druid/sql/calcite/run/SqlEngine.java | 14 +- .../druid/sql/calcite/view/ViewSqlEngine.java | 12 +- .../apache/druid/sql/http/SqlResource.java | 35 +- .../sql/calcite/CalciteScanSignatureTest.java | 12 +- .../sql/calcite/IngestionTestSqlEngine.java | 12 +- .../calcite/util/TestTimelineServerView.java | 2 - 124 files changed, 11309 insertions(+), 273 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/Dart.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/DartResourcePermissionMapper.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerMessageListener.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactory.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactoryImpl.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerRegistry.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelayFactoryImpl.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelays.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicer.java 
create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponse.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ControllerMessage.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/DoneReadingInput.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/PartialKeyStatistics.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ResultsComplete.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerError.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerWarning.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClient.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactory.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactoryImpl.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImpl.java create mode 
100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClients.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlEngine.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerConfig.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerMemoryManagementModule.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerModule.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartModules.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerConfig.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerMemoryManagementModule.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartControllerClient.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataSegmentProvider.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartQueryableSegment.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerClient.java create mode 100644 
extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactory.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRetryPolicy.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRunner.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/WorkerId.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfo.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerResource.java create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponse.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicerTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartQueryInfoTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/messages/ControllerMessageTest.java create mode 100644 
extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartQueryableSegmentTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartWorkerRunnerTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/WorkerIdTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfoTest.java create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponseTest.java create mode 100644 processing/src/main/java/org/apache/druid/common/guava/FutureBox.java create mode 100644 processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java create mode 100644 processing/src/test/java/org/apache/druid/common/guava/FutureBoxTest.java create mode 100644 processing/src/test/java/org/apache/druid/io/LimitedOutputStreamTest.java create mode 100644 server/src/main/java/org/apache/druid/messages/MessageBatch.java create mode 100644 server/src/main/java/org/apache/druid/messages/client/MessageListener.java create mode 100644 server/src/main/java/org/apache/druid/messages/client/MessageRelay.java create mode 100644 server/src/main/java/org/apache/druid/messages/client/MessageRelayClient.java create mode 100644 server/src/main/java/org/apache/druid/messages/client/MessageRelayClientImpl.java create mode 100644 server/src/main/java/org/apache/druid/messages/client/MessageRelayFactory.java create mode 100644 server/src/main/java/org/apache/druid/messages/client/MessageRelays.java create mode 100644 server/src/main/java/org/apache/druid/messages/package-info.java create mode 100644 server/src/main/java/org/apache/druid/messages/server/MessageRelayMonitor.java create mode 100644 
server/src/main/java/org/apache/druid/messages/server/MessageRelayResource.java create mode 100644 server/src/main/java/org/apache/druid/messages/server/Outbox.java create mode 100644 server/src/main/java/org/apache/druid/messages/server/OutboxImpl.java create mode 100644 server/src/main/java/org/apache/druid/rpc/FixedServiceLocator.java delete mode 100644 server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java create mode 100644 server/src/test/java/org/apache/druid/messages/MessageBatchTest.java create mode 100644 server/src/test/java/org/apache/druid/messages/client/MessageRelayClientImplTest.java create mode 100644 server/src/test/java/org/apache/druid/messages/client/MessageRelaysTest.java create mode 100644 server/src/test/java/org/apache/druid/messages/server/OutboxImplTest.java rename server/src/test/java/org/apache/druid/rpc/{FixedSetServiceLocatorTest.java => FixedServiceLocatorTest.java} (56%) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/Dart.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/Dart.java new file mode 100644 index 000000000000..33e239161ffe --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/Dart.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart; + +import com.google.inject.BindingAnnotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Binding annotation for implementations of interfaces that are Dart (MSQ-on-Broker-and-Historicals) focused. + */ +@Target({ElementType.FIELD, ElementType.PARAMETER, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@BindingAnnotation +public @interface Dart +{ +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/DartResourcePermissionMapper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/DartResourcePermissionMapper.java new file mode 100644 index 000000000000..038d1b56c72b --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/DartResourcePermissionMapper.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart; + +import com.google.common.collect.ImmutableList; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; +import org.apache.druid.server.security.Action; +import org.apache.druid.server.security.Resource; +import org.apache.druid.server.security.ResourceAction; + +import java.util.List; + +public class DartResourcePermissionMapper implements ResourcePermissionMapper +{ + /** + * Permissions for admin APIs in {@link DartWorkerResource} and {@link WorkerResource}: READ and WRITE on + * the STATE resource. Queries from end users go through {@link DartSqlResource}, which does not use these mappings. + */ + @Override + public List getAdminPermissions() + { + return ImmutableList.of( + new ResourceAction(Resource.STATE_RESOURCE, Action.READ), + new ResourceAction(Resource.STATE_RESOURCE, Action.WRITE) + ); + } + + /** + * Permissions for per-query APIs in {@link DartWorkerResource} and {@link WorkerResource}. Identical to + * {@link #getAdminPermissions()} for every queryId. End-user queries go through {@link DartSqlResource} instead. + */ + @Override + public List getQueryPermissions(String queryId) + { + return getAdminPermissions(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java new file mode 100644 index 000000000000..9644444dad24 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerHolder.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.base.Preconditions; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.QueryListener; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.WorkerFailedFault; +import org.apache.druid.server.security.AuthenticationResult; +import org.joda.time.DateTime; + +import java.util.concurrent.atomic.AtomicReference; + +/** + * Holder for {@link Controller}, stored in {@link DartControllerRegistry}. Tracks the query's {@link State}. + */ +public class ControllerHolder +{ + public enum State + { + /** + * Query has been accepted, but {@link Controller#run(QueryListener)} has not yet been called. + */ + ACCEPTED, + + /** + * Query has had {@link Controller#run(QueryListener)} called. + */ + RUNNING, + + /** + * Query has been canceled. 
+ */ + CANCELED + } + + private final Controller controller; + private final ControllerContext controllerContext; + private final String sqlQueryId; + private final String sql; + private final AuthenticationResult authenticationResult; + private final DateTime startTime; + private final AtomicReference state = new AtomicReference<>(State.ACCEPTED); + + public ControllerHolder( + final Controller controller, + final ControllerContext controllerContext, + final String sqlQueryId, + final String sql, + final AuthenticationResult authenticationResult, + final DateTime startTime + ) + { + this.controller = Preconditions.checkNotNull(controller, "controller"); + this.controllerContext = controllerContext; + this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId"); + this.sql = sql; + this.authenticationResult = authenticationResult; + this.startTime = Preconditions.checkNotNull(startTime, "startTime"); + } + + public Controller getController() + { + return controller; + } + + public String getSqlQueryId() + { + return sqlQueryId; + } + + public String getSql() + { + return sql; + } + + public AuthenticationResult getAuthenticationResult() + { + return authenticationResult; + } + + public DateTime getStartTime() + { + return startTime; + } + + public State getState() + { + return state.get(); + } + + /** + * Call when a worker has gone offline. Closes its client and, if the controller has that worker, sends a + * {@link Controller#workerError} to the controller. + */ + public void workerOffline(final WorkerId workerId) + { + final String workerIdString = workerId.toString(); + + if (controllerContext instanceof DartControllerContext) { + // For DartControllerContext, newWorkerClient() returns the same instance every time. + // This will always be DartControllerContext in production; the instanceof check is here because certain + // tests use a different context class. 
+ ((DartWorkerClient) controllerContext.newWorkerClient()).closeClient(workerId.getHostAndPort()); + } + + if (controller.hasWorker(workerIdString)) { + controller.workerError( + MSQErrorReport.fromFault( + workerIdString, + workerId.getHostAndPort(), + null, + new WorkerFailedFault(workerIdString, "Worker went offline") + ) + ); + } + } + + /** + * Places this holder into {@link State#CANCELED}. Calls {@link Controller#stop()} if it was previously in + * state {@link State#RUNNING}. + */ + public void cancel() + { + if (state.getAndSet(State.CANCELED) == State.RUNNING) { + controller.stop(); + } + } + + /** + * Calls {@link Controller#run(QueryListener)}, and returns true, if this holder was previously in state + * {@link State#ACCEPTED}. Otherwise (for example, after {@link #cancel()}) returns false without calling it. + * + * @return whether {@link Controller#run(QueryListener)} was called. + */ + public boolean run(final QueryListener listener) throws Exception + { + if (state.compareAndSet(State.ACCEPTED, State.RUNNING)) { + controller.run(listener); + return true; + } else { + return false; + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerMessageListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerMessageListener.java new file mode 100644 index 000000000000..5cedd13baf0d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/ControllerMessageListener.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.google.inject.Inject; +import org.apache.druid.messages.client.MessageListener; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.server.DruidNode; + +/** + * Listener for worker-to-controller messages. Messages for which the registry returns no holder are ignored. + * Also responsible for calling {@link Controller#workerError(MSQErrorReport)} when a worker server goes away. + */ +public class ControllerMessageListener implements MessageListener +{ + private final DartControllerRegistry controllerRegistry; + + @Inject + public ControllerMessageListener(final DartControllerRegistry controllerRegistry) + { + this.controllerRegistry = controllerRegistry; + } + + @Override + public void messageReceived(ControllerMessage message) + { + final ControllerHolder holder = controllerRegistry.get(message.getQueryId()); + if (holder != null) { + message.handle(holder.getController()); + } + } + + @Override + public void serverAdded(DruidNode node) + { + // Nothing to do. 
+ } + + @Override + public void serverRemoved(DruidNode node) + { + for (final ControllerHolder holder : controllerRegistry.getAllHolders()) { + final Controller controller = holder.getController(); + final WorkerId workerId = WorkerId.fromDruidNode(node, controller.queryId()); + holder.workerOffline(workerId); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java new file mode 100644 index 000000000000..0248e66fd221 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContext.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Injector; +import org.apache.druid.client.BrokerServerView; +import org.apache.druid.error.DruidException; +import org.apache.druid.indexing.common.TaskLockType; +import org.apache.druid.indexing.common.actions.TaskActionClient; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.java.util.emitter.service.ServiceMetricEvent; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.ControllerMemoryParameters; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.WorkerFailureListener; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.indexing.IndexerControllerContext; +import org.apache.druid.msq.indexing.MSQSpec; +import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; +import org.apache.druid.msq.querykit.QueryKit; +import org.apache.druid.msq.querykit.QueryKitSpec; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.Query; +import org.apache.druid.query.QueryContext; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.coordination.DruidServerMetadata; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Dart implementation of {@link ControllerContext}. + * Each instance is scoped to a query. + */ +public class DartControllerContext implements ControllerContext +{ + /** + * Default for {@link ControllerQueryKernelConfig#getMaxConcurrentStages()}.
+ */ + public static final int DEFAULT_MAX_CONCURRENT_STAGES = 2; + + /** + * Default for {@link MultiStageQueryContext#getTargetPartitionsPerWorkerWithDefault(QueryContext, int)}. + */ + public static final int DEFAULT_TARGET_PARTITIONS_PER_WORKER = 1; + + /** + * Context parameter for maximum number of non-leaf workers. + */ + public static final String CTX_MAX_NON_LEAF_WORKER_COUNT = "maxNonLeafWorkers"; + + /** + * Default to scatter/gather style: fan in to a single worker after the leaf stage(s). + */ + public static final int DEFAULT_MAX_NON_LEAF_WORKER_COUNT = 1; + + private final Injector injector; + private final ObjectMapper jsonMapper; + private final DruidNode selfNode; + private final DartWorkerClient workerClient; + private final BrokerServerView serverView; + private final MemoryIntrospector memoryIntrospector; + private final ServiceMetricEvent.Builder metricBuilder; + private final ServiceEmitter emitter; + + public DartControllerContext( + final Injector injector, + final ObjectMapper jsonMapper, + final DruidNode selfNode, + final DartWorkerClient workerClient, + final MemoryIntrospector memoryIntrospector, + final BrokerServerView serverView, + final ServiceEmitter emitter + ) + { + this.injector = injector; + this.jsonMapper = jsonMapper; + this.selfNode = selfNode; + this.workerClient = workerClient; + this.serverView = serverView; + this.memoryIntrospector = memoryIntrospector; + this.metricBuilder = new ServiceMetricEvent.Builder(); + this.emitter = emitter; + } + + @Override + public ControllerQueryKernelConfig queryKernelConfig( + final String queryId, + final MSQSpec querySpec + ) + { + final List<DruidServerMetadata> servers = serverView.getDruidServerMetadatas(); + + // Lock in the list of workers when creating the kernel config. There is a race here: the serverView itself is + // allowed to float.
If a segment moves to a new server that isn't part of our list after the WorkerManager is + // created, we won't be able to find a valid server for certain segments. This isn't expected to be a problem, + // since the serverView is referenced shortly after the worker list is created. + final List<String> workerIds = new ArrayList<>(servers.size()); + for (final DruidServerMetadata server : servers) { + workerIds.add(WorkerId.fromDruidServerMetadata(server, queryId).toString()); + } + + // Shuffle workerIds, so we don't bias towards specific servers when running multiple queries concurrently. For any + // given query, lower-numbered workers tend to do more work, because the controller prefers using lower-numbered + // workers when maxWorkerCount for a stage is less than the total number of workers. + Collections.shuffle(workerIds); + + final ControllerMemoryParameters memoryParameters = + ControllerMemoryParameters.createProductionInstance( + memoryIntrospector, + workerIds.size() + ); + + final int maxConcurrentStages = MultiStageQueryContext.getMaxConcurrentStagesWithDefault( + querySpec.getQuery().context(), + DEFAULT_MAX_CONCURRENT_STAGES + ); + + return ControllerQueryKernelConfig + .builder() + .controllerHost(selfNode.getHostAndPortToUse()) + .workerIds(workerIds) + .pipeline(maxConcurrentStages > 1) + .destination(TaskReportMSQDestination.instance()) + .maxConcurrentStages(maxConcurrentStages) + .maxRetainedPartitionSketchBytes(memoryParameters.getPartitionStatisticsMaxRetainedBytes()) + .workerContextMap(IndexerControllerContext.makeWorkerContextMap(querySpec, false, maxConcurrentStages)) + .build(); + } + + @Override + public ObjectMapper jsonMapper() + { + return jsonMapper; + } + + @Override + public Injector injector() + { + return injector; + } + + @Override + public void emitMetric(final String metric, final Number value) + { + emitter.emit(metricBuilder.setMetric(metric, value)); + } + + @Override + public DruidNode selfNode() + { + return selfNode; + } + +
@Override + public InputSpecSlicer newTableInputSpecSlicer(WorkerManager workerManager) + { + return DartTableInputSpecSlicer.createFromWorkerIds(workerManager.getWorkerIds(), serverView); + } + + @Override + public TaskActionClient taskActionClient() + { + throw new UnsupportedOperationException(); + } + + @Override + public WorkerManager newWorkerManager( + String queryId, + MSQSpec querySpec, + ControllerQueryKernelConfig queryKernelConfig, + WorkerFailureListener workerFailureListener + ) + { + // We're ignoring WorkerFailureListener. Dart worker failures are routed into the controller by + // ControllerMessageListener, which receives a notification when a worker goes offline. + return new DartWorkerManager(queryKernelConfig.getWorkerIds(), workerClient); + } + + @Override + public DartWorkerClient newWorkerClient() + { + return workerClient; + } + + /** + * Registers this context's {@link DartWorkerClient} with the provided closer. + */ + @Override + public void registerController(Controller controller, Closer closer) + { + closer.register(workerClient); + } + + @Override + public QueryKitSpec makeQueryKitSpec( + final QueryKit<Query<?>> queryKit, + final String queryId, + final MSQSpec querySpec, + final ControllerQueryKernelConfig queryKernelConfig + ) + { + final QueryContext queryContext = querySpec.getQuery().context(); + return new QueryKitSpec( + queryKit, + queryId, + queryKernelConfig.getWorkerIds().size(), + queryContext.getInt( + CTX_MAX_NON_LEAF_WORKER_COUNT, + DEFAULT_MAX_NON_LEAF_WORKER_COUNT + ), + MultiStageQueryContext.getTargetPartitionsPerWorkerWithDefault( + queryContext, + DEFAULT_TARGET_PARTITIONS_PER_WORKER + ) + ); + } + + @Override + public TaskLockType taskLockType() + { + throw DruidException.defensive("TaskLockType is not used with class[%s]", getClass().getName()); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactory.java new file
mode 100644 index 000000000000..f58eb4bfa68d --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactory.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import org.apache.druid.msq.dart.controller.sql.DartQueryMaker; +import org.apache.druid.msq.exec.ControllerContext; + +/** + * Factory for creating a {@link ControllerContext} for the query with a given ID, for use in {@link DartQueryMaker}. + */ +public interface DartControllerContextFactory +{ + ControllerContext newContext(String queryId); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactoryImpl.java new file mode 100644 index 000000000000..8cefb6af7ece --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerContextFactoryImpl.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Inject; +import com.google.inject.Injector; +import org.apache.druid.client.BrokerServerView; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.guice.annotations.Json; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.java.util.emitter.service.ServiceEmitter; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.server.DruidNode; + +/** + * Injectable implementation of {@link DartControllerContextFactory} that creates {@link DartControllerContext}. + */ +public class DartControllerContextFactoryImpl implements DartControllerContextFactory +{ + private final Injector injector; + private final ObjectMapper jsonMapper; + private final ObjectMapper smileMapper; + private final DruidNode selfNode; + private final ServiceClientFactory serviceClientFactory; + private final BrokerServerView serverView; + private final MemoryIntrospector memoryIntrospector; + private final ServiceEmitter emitter; + + @Inject + public DartControllerContextFactoryImpl( + final Injector injector, +
@Json final ObjectMapper jsonMapper, + @Smile final ObjectMapper smileMapper, + @Self final DruidNode selfNode, + @EscalatedGlobal final ServiceClientFactory serviceClientFactory, + final MemoryIntrospector memoryIntrospector, + final BrokerServerView serverView, + final ServiceEmitter emitter + ) + { + this.injector = injector; + this.jsonMapper = jsonMapper; + this.smileMapper = smileMapper; + this.selfNode = selfNode; + this.serviceClientFactory = serviceClientFactory; + this.serverView = serverView; + this.memoryIntrospector = memoryIntrospector; + this.emitter = emitter; + } + + /** + * Creates a {@link DartControllerContext} backed by a new {@link DartWorkerClient} for the given query ID. + */ + @Override + public ControllerContext newContext(final String queryId) + { + return new DartControllerContext( + injector, + jsonMapper, + selfNode, + new DartWorkerClient(queryId, serviceClientFactory, smileMapper, selfNode.getHostAndPortToUse()), + memoryIntrospector, + serverView, + emitter + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerRegistry.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerRegistry.java new file mode 100644 index 000000000000..847dbf759806 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartControllerRegistry.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import org.apache.druid.error.DruidException; +import org.apache.druid.msq.exec.Controller; + +import javax.annotation.Nullable; +import java.util.Collection; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Registry for actively-running {@link Controller}. + */ +public class DartControllerRegistry +{ + private final ConcurrentHashMap<String, ControllerHolder> controllerMap = new ConcurrentHashMap<>(); + + /** + * Add a controller. Throws {@link DruidException} if a controller with the same {@link Controller#queryId()} is + * already registered. + */ + public void register(ControllerHolder holder) + { + if (controllerMap.putIfAbsent(holder.getController().queryId(), holder) != null) { + throw DruidException.defensive("Controller[%s] already registered", holder.getController().queryId()); + } + } + + /** + * Remove a controller from the registry. + */ + public void deregister(ControllerHolder holder) + { + // Remove only if the current mapping for the queryId is this specific controller. + controllerMap.remove(holder.getController().queryId(), holder); + } + + /** + * Return a specific controller holder, or null if it doesn't exist. + */ + @Nullable + public ControllerHolder get(final String queryId) + { + return controllerMap.get(queryId); + } + + /** + * Returns holders for all actively-running {@link Controller}. The returned collection is a live view of the registry.
+ */ + public Collection<ControllerHolder> getAllHolders() + { + return controllerMap.values(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelayFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelayFactoryImpl.java new file mode 100644 index 000000000000..7f16a37c9d72 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelayFactoryImpl.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Inject; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.messages.client.MessageRelay; +import org.apache.druid.messages.client.MessageRelayClientImpl; +import org.apache.druid.messages.client.MessageRelayFactory; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.rpc.FixedServiceLocator; +import org.apache.druid.rpc.ServiceClient; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.apache.druid.rpc.StandardRetryPolicy; +import org.apache.druid.server.DruidNode; + +/** + * Production implementation of {@link MessageRelayFactory}. + */ +public class DartMessageRelayFactoryImpl implements MessageRelayFactory<ControllerMessage> +{ + private final String clientHost; + private final ControllerMessageListener messageListener; + private final ServiceClientFactory clientFactory; + private final String basePath; + private final ObjectMapper smileMapper; + + @Inject + public DartMessageRelayFactoryImpl( + @Self DruidNode selfNode, + @EscalatedGlobal ServiceClientFactory clientFactory, + @Smile ObjectMapper smileMapper, + ControllerMessageListener messageListener + ) + { + this.clientHost = selfNode.getHostAndPortToUse(); + this.messageListener = messageListener; + this.clientFactory = clientFactory; + this.smileMapper = smileMapper; + this.basePath = DartWorkerResource.PATH + "/relay"; + } + + /** + * Creates a relay that delivers messages from the given node to the {@link ControllerMessageListener}. The client + * targets the node's relay path under {@link DartWorkerResource#PATH} and retries without limit. + */ + @Override + public MessageRelay<ControllerMessage> newRelay(DruidNode clientNode) + { + final ServiceLocation location = ServiceLocation.fromDruidNode(clientNode).withBasePath(basePath); + final ServiceClient client = clientFactory.makeClient( + clientNode.getHostAndPortToUse(), + new
FixedServiceLocator(location), + StandardRetryPolicy.unlimited() + ); + + return new MessageRelay<>( + clientHost, + clientNode, + new MessageRelayClientImpl<>(client, smileMapper, ControllerMessage.class), + messageListener + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelays.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelays.java new file mode 100644 index 000000000000..23accd35ecbe --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartMessageRelays.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.messages.client.MessageRelayFactory; +import org.apache.druid.messages.client.MessageRelays; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; + +/** + * Specialized {@link MessageRelays} for Dart controllers.
+ */ +public class DartMessageRelays extends MessageRelays<ControllerMessage> +{ + /** + * Relays are set up against nodes discovered with role {@link NodeRole#HISTORICAL}, which is where Dart workers run. + */ + public DartMessageRelays( + final DruidNodeDiscoveryProvider discoveryProvider, + final MessageRelayFactory<ControllerMessage> messageRelayFactory + ) + { + super(() -> discoveryProvider.getForNodeRole(NodeRole.HISTORICAL), messageRelayFactory); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicer.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicer.java new file mode 100644 index 000000000000..52ecccbc152f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicer.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.collect.FluentIterable; +import com.google.common.collect.ImmutableList; +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.client.TimelineServerView; +import org.apache.druid.client.selector.QueryableDruidServer; +import org.apache.druid.client.selector.ServerSelector; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.JodaUtils; +import org.apache.druid.msq.dart.worker.DartQueryableSegment; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.SegmentSource; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.InputSpec; +import org.apache.druid.msq.input.InputSpecSlicer; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.TableInputSpec; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.DimFilterUtils; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.TimelineLookup; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Set; +import java.util.function.ToIntFunction; + +/** + * Slices {@link TableInputSpec} into {@link SegmentsInputSlice} for persistent servers using + * {@link TimelineServerView}. + */ +public class DartTableInputSpecSlicer implements InputSpecSlicer +{ + private static final int UNKNOWN = -1; + + /** + * Worker host:port -> worker number. This is the reverse of the mapping from {@link WorkerManager#getWorkerIds()}.
+ */ + private final Object2IntMap<String> workerIdToNumber; + + /** + * Server view for identifying which segments exist and which servers (workers) have which segments. + */ + private final TimelineServerView serverView; + + DartTableInputSpecSlicer(final Object2IntMap<String> workerIdToNumber, final TimelineServerView serverView) + { + this.workerIdToNumber = workerIdToNumber; + this.serverView = serverView; + } + + /** + * Creates a slicer from a list of worker IDs, where each worker's index in the list is its worker number. + */ + public static DartTableInputSpecSlicer createFromWorkerIds( + final List<String> workerIds, + final TimelineServerView serverView + ) + { + final Object2IntMap<String> reverseWorkers = new Object2IntOpenHashMap<>(); + reverseWorkers.defaultReturnValue(UNKNOWN); + + for (int i = 0; i < workerIds.size(); i++) { + reverseWorkers.put(WorkerId.fromString(workerIds.get(i)).getHostAndPort(), i); + } + + return new DartTableInputSpecSlicer(reverseWorkers, serverView); + } + + @Override + public boolean canSliceDynamic(final InputSpec inputSpec) + { + return false; + } + + /** + * Assigns each queryable segment to the worker picked for it, falling back to round-robin assignment for segments + * with no known worker. Returns an empty list if the datasource has no timeline. + */ + @Override + public List<InputSlice> sliceStatic(final InputSpec inputSpec, final int maxNumSlices) + { + final TableInputSpec tableInputSpec = (TableInputSpec) inputSpec; + final TimelineLookup<String, ServerSelector> timeline = + serverView.getTimeline(new TableDataSource(tableInputSpec.getDataSource()).getAnalysis()).orElse(null); + + if (timeline == null) { + return Collections.emptyList(); + } + + final Set<DartQueryableSegment> prunedSegments = + findQueryableDataSegments( + tableInputSpec, + timeline, + serverSelector -> findWorkerForServerSelector(serverSelector, maxNumSlices) + ); + + final List<List<DartQueryableSegment>> assignments = new ArrayList<>(maxNumSlices); + while (assignments.size() < maxNumSlices) { + assignments.add(null); + } + + int nextRoundRobinWorker = 0; + for (final DartQueryableSegment segment : prunedSegments) { + final int worker; + if (segment.getWorkerNumber() == UNKNOWN) { + // Segment is not available on any worker. Assign to some worker, round-robin.
Today, that server will throw + // an error about the segment not being findable, but perhaps one day, it will be able to load the segment + // on demand. + worker = nextRoundRobinWorker; + nextRoundRobinWorker = (nextRoundRobinWorker + 1) % maxNumSlices; + } else { + worker = segment.getWorkerNumber(); + } + + if (assignments.get(worker) == null) { + assignments.set(worker, new ArrayList<>()); + } + + assignments.get(worker).add(segment); + } + + return makeSegmentSlices(tableInputSpec.getDataSource(), assignments); + } + + @Override + public List<InputSlice> sliceDynamic( + final InputSpec inputSpec, + final int maxNumSlices, + final int maxFilesPerSlice, + final long maxBytesPerSlice + ) + { + throw new UnsupportedOperationException(); + } + + /** + * Return the worker number that corresponds to a particular {@link ServerSelector}, or {@link #UNKNOWN} if none does. + * + * @param serverSelector the server selector + * @param maxNumSlices maximum number of workers to use; worker numbers at or above this map to {@link #UNKNOWN} + */ + int findWorkerForServerSelector(final ServerSelector serverSelector, final int maxNumSlices) + { + final QueryableDruidServer server = serverSelector.pick(null); + + if (server == null) { + return UNKNOWN; + } + + final String serverHostAndPort = server.getServer().getHostAndPort(); + final int workerNumber = workerIdToNumber.getInt(serverHostAndPort); + + // The worker number may be UNKNOWN in a race condition, such as the set of Historicals changing while + // the query is being planned. I don't think it can be >= maxNumSlices, but if it is, treat it like UNKNOWN. + if (workerNumber != UNKNOWN && workerNumber < maxNumSlices) { + return workerNumber; + } else { + return UNKNOWN; + } + } + + /** + * Pull the list of {@link DataSegment} that we should query, along with a clipping interval for each one, and + * a worker to get it from.
+ */ + static Set<DartQueryableSegment> findQueryableDataSegments( + final TableInputSpec tableInputSpec, + final TimelineLookup<?, ServerSelector> timeline, + final ToIntFunction<ServerSelector> toWorkersFunction + ) + { + final FluentIterable<DartQueryableSegment> allSegments = + FluentIterable.from(JodaUtils.condenseIntervals(tableInputSpec.getIntervals())) + .transformAndConcat(timeline::lookup) + .transformAndConcat( + holder -> + FluentIterable + .from(holder.getObject()) + .filter(chunk -> shouldIncludeSegment(chunk.getObject())) + .transform(chunk -> { + final ServerSelector serverSelector = chunk.getObject(); + final DataSegment dataSegment = serverSelector.getSegment(); + final int worker = toWorkersFunction.applyAsInt(serverSelector); + return new DartQueryableSegment(dataSegment, holder.getInterval(), worker); + }) + .filter(segment -> !segment.getSegment().isTombstone()) + ); + + return DimFilterUtils.filterShards( + tableInputSpec.getFilter(), + tableInputSpec.getFilterFields(), + allSegments, + segment -> segment.getSegment().getShardSpec(), + new HashMap<>() + ); + } + + /** + * Create a list of {@link SegmentsInputSlice} and {@link NilInputSlice} assignments.
+ * + * @param dataSource datasource to read + * @param assignments list of assignment lists, one per slice + * + * @return a list of the same length as "assignments" + * + * @throws IllegalStateException if any provided segments do not match the provided datasource + */ + static List<InputSlice> makeSegmentSlices( + final String dataSource, + final List<List<DartQueryableSegment>> assignments + ) + { + final List<InputSlice> retVal = new ArrayList<>(assignments.size()); + + for (final List<DartQueryableSegment> assignment : assignments) { + if (assignment == null || assignment.isEmpty()) { + retVal.add(NilInputSlice.INSTANCE); + } else { + final List<RichSegmentDescriptor> descriptors = new ArrayList<>(); + for (final DartQueryableSegment segment : assignment) { + if (!dataSource.equals(segment.getSegment().getDataSource())) { + throw new ISE("Expected dataSource[%s] but got[%s]", dataSource, segment.getSegment().getDataSource()); + } + + descriptors.add(toRichSegmentDescriptor(segment)); + } + + retVal.add(new SegmentsInputSlice(dataSource, descriptors, ImmutableList.of())); + } + } + + return retVal; + } + + /** + * Returns a {@link RichSegmentDescriptor}, which is used by {@link SegmentsInputSlice}. + */ + static RichSegmentDescriptor toRichSegmentDescriptor(final DartQueryableSegment segment) + { + return new RichSegmentDescriptor( + segment.getSegment().getInterval(), + segment.getInterval(), + segment.getSegment().getVersion(), + segment.getSegment().getShardSpec().getPartitionNum() + ); + } + + /** + * Whether to include a segment from the timeline. Segments are included if they are not tombstones, and are also not + * purely realtime segments. Segments that are not on any server at all are included.
+ */ + static boolean shouldIncludeSegment(final ServerSelector serverSelector) + { + if (serverSelector.getSegment().isTombstone()) { + return false; + } + + int numRealtimeServers = 0; + int numOtherServers = 0; + + for (final DruidServerMetadata server : serverSelector.getAllServers()) { + if (SegmentSource.REALTIME.getUsedServerTypes().contains(server.getType())) { + numRealtimeServers++; + } else { + numOtherServers++; + } + } + + return numOtherServers > 0 || (numOtherServers + numRealtimeServers == 0); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java new file mode 100644 index 000000000000..54e163862d62 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/DartWorkerManager.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import it.unimi.dsi.fastutil.ints.Int2ObjectAVLTreeMap; +import it.unimi.dsi.fastutil.ints.Int2ObjectMap; +import it.unimi.dsi.fastutil.objects.Object2IntMap; +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.exec.ControllerContext; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.exec.WorkerStats; +import org.apache.druid.msq.indexing.WorkerCount; +import org.apache.druid.utils.CloseableUtils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Dart implementation of the {@link WorkerManager} returned by {@link ControllerContext#newWorkerManager}. + * + * This manager does not actually launch workers. The workers are housed on long-lived servers outside of this + * manager's control. This manager merely reports on their existence. 
+ */ +public class DartWorkerManager implements WorkerManager +{ + private static final Logger log = new Logger(DartWorkerManager.class); + + private final List workerIds; + private final DartWorkerClient workerClient; + private final Object2IntMap workerIdToNumber; + private final AtomicReference state = new AtomicReference<>(State.NEW); + private final SettableFuture stopFuture = SettableFuture.create(); + + enum State + { + NEW, + STARTED, + STOPPED + } + + public DartWorkerManager( + final List workerIds, + final DartWorkerClient workerClient + ) + { + this.workerIds = workerIds; + this.workerClient = workerClient; + this.workerIdToNumber = new Object2IntOpenHashMap<>(); + this.workerIdToNumber.defaultReturnValue(UNKNOWN_WORKER_NUMBER); + + for (int i = 0; i < workerIds.size(); i++) { + workerIdToNumber.put(workerIds.get(i), i); + } + } + + @Override + public ListenableFuture start() + { + if (!state.compareAndSet(State.NEW, State.STARTED)) { + throw new ISE("Cannot start from state[%s]", state.get()); + } + + return stopFuture; + } + + @Override + public void launchWorkersIfNeeded(int workerCount) + { + // Nothing to do, just validate the count. + if (workerCount > workerIds.size()) { + throw DruidException.defensive( + "Desired workerCount[%s] must be less than or equal to actual workerCount[%s]", + workerCount, + workerIds.size() + ); + } + } + + @Override + public void waitForWorkers(Set workerNumbers) + { + // Nothing to wait for, just validate the numbers. 
+ for (final int workerNumber : workerNumbers) { + if (workerNumber >= workerIds.size()) { + throw DruidException.defensive( + "Desired workerNumber[%s] must be less than workerCount[%s]", + workerNumber, + workerIds.size() + ); + } + } + } + + @Override + public List getWorkerIds() + { + return workerIds; + } + + @Override + public WorkerCount getWorkerCount() + { + return new WorkerCount(workerIds.size(), 0); + } + + @Override + public int getWorkerNumber(String workerId) + { + return workerIdToNumber.getInt(workerId); + } + + @Override + public boolean isWorkerActive(String workerId) + { + return workerIdToNumber.containsKey(workerId); + } + + @Override + public Map> getWorkerStats() + { + final Int2ObjectMap> retVal = new Int2ObjectAVLTreeMap<>(); + + for (int i = 0; i < workerIds.size(); i++) { + retVal.put(i, Collections.singletonList(new WorkerStats(workerIds.get(i), TaskState.RUNNING, -1, -1))); + } + + return retVal; + } + + /** + * Stop method. Possibly signals workers to stop, but does not actually wait for them to exit. + * + * If "interrupt" is false, does nothing special (other than setting {@link #stopFuture}). The assumption is that + * a previous call to {@link WorkerClient#postFinish} would have caused the worker to exit. + * + * If "interrupt" is true, sends {@link DartWorkerClient#stopWorker(String)} to workers to stop the current query ID. + * + * @param interrupt whether to interrupt currently-running work + */ + @Override + public void stop(boolean interrupt) + { + if (state.compareAndSet(State.STARTED, State.STOPPED)) { + final List> futures = new ArrayList<>(); + + // Send stop commands to all workers. This ensures they exit promptly, and do not get left in a zombie state. + // For this reason, the workerClient uses an unlimited retry policy. If a stop command is lost, a worker + // could get stuck in a zombie state without its controller. This state would persist until the server that + // ran the controller shuts down or restarts. 
At that time, the listener in DartWorkerRunner.BrokerListener calls + // "controllerFailed()" on the Worker, and the zombie worker would exit. + + for (final String workerId : workerIds) { + futures.add(workerClient.stopWorker(workerId)); + } + + // Block until messages are acknowledged, or until the worker we're communicating with has failed. + + try { + FutureUtils.getUnchecked(Futures.successfulAsList(futures), false); + } + catch (Throwable ignored) { + // Suppress errors. + } + + CloseableUtils.closeAndSuppressExceptions(workerClient, e -> log.warn(e, "Failed to close workerClient")); + stopFuture.set(null); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java new file mode 100644 index 000000000000..e5f3abb894e1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartQueryInfo.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.util.MSQTaskQueryMakerUtils; +import org.apache.druid.query.QueryContexts; +import org.joda.time.DateTime; + +import java.util.Objects; + +/** + * Class included in {@link GetQueriesResponse}. + */ +public class DartQueryInfo +{ + private final String sqlQueryId; + private final String dartQueryId; + private final String sql; + private final String authenticator; + private final String identity; + private final DateTime startTime; + private final String state; + + @JsonCreator + public DartQueryInfo( + @JsonProperty("sqlQueryId") final String sqlQueryId, + @JsonProperty("dartQueryId") final String dartQueryId, + @JsonProperty("sql") final String sql, + @JsonProperty("authenticator") final String authenticator, + @JsonProperty("identity") final String identity, + @JsonProperty("startTime") final DateTime startTime, + @JsonProperty("state") final String state + ) + { + this.sqlQueryId = Preconditions.checkNotNull(sqlQueryId, "sqlQueryId"); + this.dartQueryId = Preconditions.checkNotNull(dartQueryId, "dartQueryId"); + this.sql = sql; + this.authenticator = authenticator; + this.identity = identity; + this.startTime = startTime; + this.state = state; + } + + public static DartQueryInfo fromControllerHolder(final ControllerHolder holder) + { + return new DartQueryInfo( + holder.getSqlQueryId(), + holder.getController().queryId(), + MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(holder.getSql()), + holder.getAuthenticationResult().getAuthenticatedBy(), + holder.getAuthenticationResult().getIdentity(), + holder.getStartTime(), + holder.getState().toString() + ); + } + + /** + * The {@link 
QueryContexts#CTX_SQL_QUERY_ID} provided by the user, or generated by the system. + */ + @JsonProperty + public String getSqlQueryId() + { + return sqlQueryId; + } + + /** + * Dart query ID generated by the system. Globally unique. + */ + @JsonProperty + public String getDartQueryId() + { + return dartQueryId; + } + + /** + * SQL string for this query, masked using {@link MSQTaskQueryMakerUtils#maskSensitiveJsonKeys(String)}. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String getSql() + { + return sql; + } + + /** + * Authenticator that authenticated the identity from {@link #getIdentity()}. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String getAuthenticator() + { + return authenticator; + } + + /** + * User that issued this query. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public String getIdentity() + { + return identity; + } + + /** + * Time this query was started. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_NULL) + public DateTime getStartTime() + { + return startTime; + } + + @JsonProperty + public String getState() + { + return state; + } + + /** + * Returns a copy of this instance with {@link #getAuthenticator()} and {@link #getIdentity()} nulled. 
+ */ + public DartQueryInfo withoutAuthenticationResult() + { + return new DartQueryInfo(sqlQueryId, dartQueryId, sql, null, null, startTime, state); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DartQueryInfo that = (DartQueryInfo) o; + return Objects.equals(sqlQueryId, that.sqlQueryId) + && Objects.equals(dartQueryId, that.dartQueryId) + && Objects.equals(sql, that.sql) + && Objects.equals(authenticator, that.authenticator) + && Objects.equals(identity, that.identity) + && Objects.equals(startTime, that.startTime) + && Objects.equals(state, that.state); + } + + @Override + public int hashCode() + { + return Objects.hash(sqlQueryId, dartQueryId, sql, authenticator, identity, startTime, state); + } + + @Override + public String toString() + { + return "DartQueryInfo{" + + "sqlQueryId='" + sqlQueryId + '\'' + + ", dartQueryId='" + dartQueryId + '\'' + + ", sql='" + sql + '\'' + + ", authenticator='" + authenticator + '\'' + + ", identity='" + identity + '\'' + + ", startTime=" + startTime + + ", state=" + state + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java new file mode 100644 index 000000000000..37e9f1051318 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/DartSqlResource.java @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.Futures; +import com.google.inject.Inject; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.sql.DartSqlClients; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.query.DefaultQueryConfig; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.ResponseContextConfig; +import org.apache.druid.server.initialization.ServerConfig; +import org.apache.druid.server.security.Access; +import org.apache.druid.server.security.Action; +import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.server.security.AuthorizationUtils; +import org.apache.druid.server.security.AuthorizerMapper; +import org.apache.druid.server.security.Resource; +import org.apache.druid.server.security.ResourceAction; +import org.apache.druid.sql.HttpStatement; +import org.apache.druid.sql.SqlLifecycleManager; +import org.apache.druid.sql.SqlStatementFactory; +import org.apache.druid.sql.http.SqlQuery; +import 
org.apache.druid.sql.http.SqlResource; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.Consumes; +import javax.ws.rs.DELETE; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import java.util.stream.Collectors; + +/** + * Resource for Dart queries. API-compatible with {@link SqlResource}, so clients can be pointed from + * {@code /druid/v2/sql/} to {@code /druid/v2/sql/dart/} without code changes. + */ +@Path(DartSqlResource.PATH + '/') +public class DartSqlResource extends SqlResource +{ + public static final String PATH = "/druid/v2/sql/dart"; + + private static final Logger log = new Logger(DartSqlResource.class); + + private final DartControllerRegistry controllerRegistry; + private final SqlLifecycleManager sqlLifecycleManager; + private final DartSqlClients sqlClients; + private final AuthorizerMapper authorizerMapper; + private final DefaultQueryConfig dartQueryConfig; + + @Inject + public DartSqlResource( + final ObjectMapper jsonMapper, + final AuthorizerMapper authorizerMapper, + @Dart final SqlStatementFactory sqlStatementFactory, + final DartControllerRegistry controllerRegistry, + final SqlLifecycleManager sqlLifecycleManager, + final DartSqlClients sqlClients, + final ServerConfig serverConfig, + final ResponseContextConfig responseContextConfig, + @Self final DruidNode selfNode, + @Dart final DefaultQueryConfig dartQueryConfig + ) + { + super( + jsonMapper, + authorizerMapper, + sqlStatementFactory, + sqlLifecycleManager, + serverConfig, + responseContextConfig, + selfNode + ); + this.controllerRegistry = controllerRegistry; 
+ this.sqlLifecycleManager = sqlLifecycleManager; + this.sqlClients = sqlClients; + this.authorizerMapper = authorizerMapper; + this.dartQueryConfig = dartQueryConfig; + } + + /** + * API that allows callers to check if this resource is installed without actually issuing a query. If installed, + * this call returns 200 OK. If not installed, callers get 404 Not Found. + */ + @GET + @Path("/enabled") + @Produces(MediaType.APPLICATION_JSON) + public Response doGetEnabled(@Context final HttpServletRequest request) + { + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(request); + return Response.ok(ImmutableMap.of("enabled", true)).build(); + } + + /** + * API to list all running queries. + * + * @param selfOnly if true, return queries running on this server. If false, return queries running on all servers. + * @param req http request + */ + @GET + @Produces(MediaType.APPLICATION_JSON) + public GetQueriesResponse doGetRunningQueries( + @QueryParam("selfOnly") final String selfOnly, + @Context final HttpServletRequest req + ) + { + final AuthenticationResult authenticationResult = AuthorizationUtils.authenticationResultFromRequest(req); + final Access stateReadAccess = AuthorizationUtils.authorizeAllResourceActions( + authenticationResult, + Collections.singletonList(new ResourceAction(Resource.STATE_RESOURCE, Action.READ)), + authorizerMapper + ); + + final List queries = + controllerRegistry.getAllHolders() + .stream() + .map(DartQueryInfo::fromControllerHolder) + .sorted(Comparator.comparing(DartQueryInfo::getStartTime)) + .collect(Collectors.toList()); + + // Add queries from all other servers, if "selfOnly" is not set. 
+ if (selfOnly == null) { + final List otherQueries = FutureUtils.getUnchecked( + Futures.successfulAsList( + Iterables.transform(sqlClients.getAllClients(), client -> client.getRunningQueries(true))), + true + ); + + for (final GetQueriesResponse response : otherQueries) { + if (response != null) { + queries.addAll(response.getQueries()); + } + } + } + + final GetQueriesResponse response; + if (stateReadAccess.isAllowed()) { + // User can READ STATE, so they can see all running queries, as well as authentication details. + response = new GetQueriesResponse(queries); + } else { + // User cannot READ STATE, so they can see only their own queries, without authentication details. + response = new GetQueriesResponse( + queries.stream() + .filter( + query -> + authenticationResult.getAuthenticatedBy() != null + && authenticationResult.getIdentity() != null + && Objects.equals(authenticationResult.getAuthenticatedBy(), query.getAuthenticator()) + && Objects.equals(authenticationResult.getIdentity(), query.getIdentity())) + .map(DartQueryInfo::withoutAuthenticationResult) + .collect(Collectors.toList()) + ); + } + + AuthorizationUtils.setRequestAuthorizationAttributeIfNeeded(req); + return response; + } + + /** + * API to issue a query. + */ + @POST + @Produces(MediaType.APPLICATION_JSON) + @Consumes(MediaType.APPLICATION_JSON) + @Override + public Response doPost( + final SqlQuery sqlQuery, + @Context final HttpServletRequest req + ) + { + final Map context = new HashMap<>(sqlQuery.getContext()); + + // Default context keys from dartQueryConfig. + for (Map.Entry entry : dartQueryConfig.getContext().entrySet()) { + context.putIfAbsent(entry.getKey(), entry.getValue()); + } + + // Dart queryId must be globally unique; cannot use user-provided sqlQueryId or queryId. 
+ final String dartQueryId = UUID.randomUUID().toString(); + context.put(DartSqlEngine.CTX_DART_QUERY_ID, dartQueryId); + + return super.doPost(sqlQuery.withOverridenContext(context), req); + } + + /** + * API to cancel a query. + */ + @DELETE + @Path("{id}") + @Produces(MediaType.APPLICATION_JSON) + @Override + public Response cancelQuery( + @PathParam("id") String sqlQueryId, + @Context final HttpServletRequest req + ) + { + log.debug("Received cancel request for query[%s]", sqlQueryId); + + List cancelables = sqlLifecycleManager.getAll(sqlQueryId); + if (cancelables.isEmpty()) { + return Response.status(Response.Status.NOT_FOUND).build(); + } + + final Access access = authorizeCancellation(req, cancelables); + + if (access.isAllowed()) { + sqlLifecycleManager.removeAll(sqlQueryId, cancelables); + + // Don't call cancel() on the cancelables. That just cancels native queries, which is useless here. Instead, + // get the controller and stop it. + boolean found = false; + for (SqlLifecycleManager.Cancelable cancelable : cancelables) { + final HttpStatement stmt = (HttpStatement) cancelable; + final Object dartQueryId = stmt.context().get(DartSqlEngine.CTX_DART_QUERY_ID); + if (dartQueryId instanceof String) { + final ControllerHolder holder = controllerRegistry.get((String) dartQueryId); + if (holder != null) { + found = true; + holder.cancel(); + } + } else { + log.warn( + "%s[%s] for query[%s] is not a string, cannot cancel.", + DartSqlEngine.CTX_DART_QUERY_ID, + dartQueryId, + sqlQueryId + ); + } + } + + return Response.status(found ? 
Response.Status.ACCEPTED : Response.Status.NOT_FOUND).build(); + } else { + return Response.status(Response.Status.FORBIDDEN).build(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponse.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponse.java new file mode 100644 index 000000000000..2d1f87f860c5 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponse.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +/** + * Class returned by {@link DartSqlResource#doGetRunningQueries}, the "list all queries" API. 
+ */ +public class GetQueriesResponse +{ + private final List queries; + + @JsonCreator + public GetQueriesResponse(@JsonProperty("queries") List queries) + { + this.queries = queries; + } + + @JsonProperty + public List getQueries() + { + return queries; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GetQueriesResponse response = (GetQueriesResponse) o; + return Objects.equals(queries, response.queries); + } + + @Override + public int hashCode() + { + return Objects.hashCode(queries); + } + + @Override + public String toString() + { + return "GetQueriesResponse{" + + "queries=" + queries + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ControllerMessage.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ControllerMessage.java new file mode 100644 index 000000000000..454e23bbc9c1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ControllerMessage.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonSubTypes; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.druid.msq.dart.worker.DartControllerClient; +import org.apache.druid.msq.exec.Controller; + +/** + * Messages sent from worker to controller by {@link DartControllerClient}. + */ +@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") +@JsonSubTypes({ + @JsonSubTypes.Type(value = PartialKeyStatistics.class, name = "partialKeyStatistics"), + @JsonSubTypes.Type(value = DoneReadingInput.class, name = "doneReadingInput"), + @JsonSubTypes.Type(value = ResultsComplete.class, name = "resultsComplete"), + @JsonSubTypes.Type(value = WorkerError.class, name = "workerError"), + @JsonSubTypes.Type(value = WorkerWarning.class, name = "workerWarning") +}) +public interface ControllerMessage +{ + /** + * Query ID, to identify the controller that is being contacted. + */ + String getQueryId(); + + /** + * Handler for this message, which calls an appropriate method on {@link Controller}. + */ + void handle(Controller controller); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/DoneReadingInput.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/DoneReadingInput.java new file mode 100644 index 000000000000..e74e5a0d1bb7 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/DoneReadingInput.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.kernel.StageId; + +import java.util.Objects; + +/** + * Message for {@link ControllerClient#postDoneReadingInput}. 
+ */ +public class DoneReadingInput implements ControllerMessage +{ + private final StageId stageId; + private final int workerNumber; + + @JsonCreator + public DoneReadingInput( + @JsonProperty("stage") final StageId stageId, + @JsonProperty("worker") final int workerNumber + ) + { + this.stageId = Preconditions.checkNotNull(stageId, "stageId"); + this.workerNumber = workerNumber; + } + + @Override + public String getQueryId() + { + return stageId.getQueryId(); + } + + @JsonProperty("stage") + public StageId getStageId() + { + return stageId; + } + + @JsonProperty("worker") + public int getWorkerNumber() + { + return workerNumber; + } + + @Override + public void handle(Controller controller) + { + controller.doneReadingInput(stageId.getStageNumber(), workerNumber); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DoneReadingInput that = (DoneReadingInput) o; + return workerNumber == that.workerNumber + && Objects.equals(stageId, that.stageId); + } + + @Override + public int hashCode() + { + return Objects.hash(stageId, workerNumber); + } + + @Override + public String toString() + { + return "DoneReadingInput{" + + "stageId=" + stageId + + ", workerNumber=" + workerNumber + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/PartialKeyStatistics.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/PartialKeyStatistics.java new file mode 100644 index 000000000000..1aa3bcb040e4 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/PartialKeyStatistics.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Preconditions; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; + +import java.util.Objects; + +/** + * Message for {@link ControllerClient#postPartialKeyStatistics}. 
+ */ +public class PartialKeyStatistics implements ControllerMessage +{ + private final StageId stageId; + private final int workerNumber; + private final PartialKeyStatisticsInformation payload; + + @JsonCreator + public PartialKeyStatistics( + @JsonProperty("stage") final StageId stageId, + @JsonProperty("worker") final int workerNumber, + @JsonProperty("payload") final PartialKeyStatisticsInformation payload + ) + { + this.stageId = Preconditions.checkNotNull(stageId, "stageId"); + this.workerNumber = workerNumber; + this.payload = payload; + } + + @Override + public String getQueryId() + { + return stageId.getQueryId(); + } + + @JsonProperty("stage") + public StageId getStageId() + { + return stageId; + } + + @JsonProperty("worker") + public int getWorkerNumber() + { + return workerNumber; + } + + @JsonProperty + public PartialKeyStatisticsInformation getPayload() + { + return payload; + } + + + @Override + public void handle(Controller controller) + { + controller.updatePartialKeyStatisticsInformation( + stageId.getStageNumber(), + workerNumber, + payload + ); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PartialKeyStatistics that = (PartialKeyStatistics) o; + return workerNumber == that.workerNumber + && Objects.equals(stageId, that.stageId) + && Objects.equals(payload, that.payload); + } + + @Override + public int hashCode() + { + return Objects.hash(stageId, workerNumber, payload); + } + + @Override + public String toString() + { + return "PartialKeyStatistics{" + + "stageId=" + stageId + + ", workerNumber=" + workerNumber + + ", payload=" + payload + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ResultsComplete.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ResultsComplete.java new file mode 100644 index 
000000000000..58822a357265
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/ResultsComplete.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.messages;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Preconditions;
+import org.apache.druid.msq.exec.Controller;
+import org.apache.druid.msq.exec.ControllerClient;
+import org.apache.druid.msq.kernel.StageId;
+
+import javax.annotation.Nullable;
+import java.util.Objects;
+
+/**
+ * Message for {@link ControllerClient#postResultsComplete}; {@link #handle} delivers it to a {@link Controller}.
+ */
+public class ResultsComplete implements ControllerMessage
+{
+  private final StageId stageId;
+  private final int workerNumber;
+
+  @Nullable
+  private final Object resultObject;
+
+  @JsonCreator
+  public ResultsComplete(
+      @JsonProperty("stage") final StageId stageId,
+      @JsonProperty("worker") final int workerNumber,
+      @Nullable @JsonProperty("result") final Object resultObject
+  )
+  {
+    this.stageId = Preconditions.checkNotNull(stageId, "stageId");
+    this.workerNumber = workerNumber;
+    this.resultObject = resultObject;
+  }
+
+  @Override
+  public String getQueryId()
+  {
+    return stageId.getQueryId();
+  }
+
+  @JsonProperty("stage")
+  public StageId getStageId()
+  {
+    return stageId;
+  }
+
+  @JsonProperty("worker")
+  public int getWorkerNumber()
+  {
+    return workerNumber;
+  }
+
+  @Nullable
+  @JsonProperty("result")
+  @JsonInclude(JsonInclude.Include.NON_NULL)
+  public Object getResultObject()
+  {
+    return resultObject;
+  }
+
+  @Override
+  public void handle(Controller controller)
+  {
+    controller.resultsComplete(stageId.getQueryId(), stageId.getStageNumber(), workerNumber, resultObject);
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (o == this) {
+      return true;
+    }
+    if (o == null || o.getClass() != getClass()) {
+      return false;
+    }
+    final ResultsComplete other = (ResultsComplete) o;
+    return Objects.equals(stageId, other.stageId)
+           && workerNumber == other.workerNumber
+           && Objects.equals(resultObject, other.resultObject);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(stageId, workerNumber, resultObject);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "ResultsComplete{stageId="
+           + stageId
+           + ", workerNumber=" + workerNumber
+           + ", resultObject=" + resultObject
+           + "}";
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerError.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerError.java
new
file mode 100644
index 000000000000..b89cfb356a36
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerError.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.messages;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Preconditions;
+import org.apache.druid.msq.exec.Controller;
+import org.apache.druid.msq.exec.ControllerClient;
+import org.apache.druid.msq.indexing.error.MSQErrorReport;
+
+import java.util.Objects;
+
+/**
+ * Message for {@link ControllerClient#postWorkerError}; {@link #handle} passes the report to a {@link Controller}.
+ */
+public class WorkerError implements ControllerMessage
+{
+  private final String queryId;
+  private final MSQErrorReport errorWrapper;
+
+  @JsonCreator
+  public WorkerError(
+      @JsonProperty("queryId") String queryId,
+      @JsonProperty("error") MSQErrorReport errorWrapper
+  )
+  {
+    this.queryId = Preconditions.checkNotNull(queryId, "queryId");
+    this.errorWrapper = Preconditions.checkNotNull(errorWrapper, "error");
+  }
+
+  @Override
+  @JsonProperty
+  public String getQueryId()
+  {
+    return queryId;
+  }
+
+  @JsonProperty("error")
+  public MSQErrorReport getErrorWrapper()
+  {
+    return errorWrapper;
+  }
+
+  @Override
+  public void handle(Controller controller)
+  {
+    controller.workerError(errorWrapper);
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (o == this) {
+      return true;
+    }
+    if (o == null || o.getClass() != getClass()) {
+      return false;
+    }
+    final WorkerError other = (WorkerError) o;
+    return Objects.equals(queryId, other.queryId)
+           && Objects.equals(errorWrapper, other.errorWrapper);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(queryId, errorWrapper);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "WorkerError{queryId='"
+           + queryId
+           + "', errorWrapper=" + errorWrapper
+           + "}";
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerWarning.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerWarning.java
new file mode 100644
index 000000000000..aa2ff6643131
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/messages/WorkerWarning.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.messages;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.base.Preconditions;
+import org.apache.druid.msq.exec.Controller;
+import org.apache.druid.msq.exec.ControllerClient;
+import org.apache.druid.msq.indexing.error.MSQErrorReport;
+
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Message for {@link ControllerClient#postWorkerWarning}. Carries a list of {@link MSQErrorReport} for a query.
+ */
+public class WorkerWarning implements ControllerMessage
+{
+  private final String queryId;
+  private final List<MSQErrorReport> errorWrappers;
+
+  @JsonCreator
+  public WorkerWarning(
+      @JsonProperty("queryId") String queryId,
+      @JsonProperty("errors") List<MSQErrorReport> errorWrappers
+  )
+  {
+    this.queryId = Preconditions.checkNotNull(queryId, "queryId");
+    this.errorWrappers = Preconditions.checkNotNull(errorWrappers, "errors"); // label matches the JSON property
+  }
+
+  @Override
+  @JsonProperty
+  public String getQueryId()
+  {
+    return queryId;
+  }
+
+  @JsonProperty("errors")
+  public List<MSQErrorReport> getErrorWrappers()
+  {
+    return errorWrappers;
+  }
+
+  @Override
+  public void handle(Controller controller)
+  {
+    controller.workerWarning(errorWrappers); // the whole list is passed in a single call
+  }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    WorkerWarning that = (WorkerWarning) o;
+    return Objects.equals(queryId, that.queryId) && Objects.equals(errorWrappers, that.errorWrappers);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(queryId, errorWrappers);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "WorkerWarning{" +
+           "queryId='" + queryId + '\'' +
+           ", errorWrappers=" + errorWrappers +
+           '}';
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java
new file mode 100644
index 000000000000..37ed936a1173
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartQueryMaker.java
@@ -0,0 +1,484 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.sql;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.Iterators;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.druid.io.LimitedOutputStream;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.Either;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.java.util.common.guava.BaseSequence;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.java.util.common.jackson.JacksonUtils;
+import org.apache.druid.java.util.common.logger.Logger;
+import org.apache.druid.msq.dart.controller.ControllerHolder;
+import org.apache.druid.msq.dart.controller.DartControllerContextFactory;
+import org.apache.druid.msq.dart.controller.DartControllerRegistry;
+import org.apache.druid.msq.dart.guice.DartControllerConfig;
+import org.apache.druid.msq.exec.Controller;
+import org.apache.druid.msq.exec.ControllerContext;
+import org.apache.druid.msq.exec.ControllerImpl;
+import org.apache.druid.msq.exec.QueryListener;
+import org.apache.druid.msq.exec.ResultsContext;
+import org.apache.druid.msq.indexing.MSQSpec;
+import org.apache.druid.msq.indexing.TaskReportQueryListener;
+import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
+import org.apache.druid.msq.indexing.error.CanceledFault;
+import org.apache.druid.msq.indexing.error.MSQErrorReport;
+import org.apache.druid.msq.indexing.report.MSQResultsReport;
+import org.apache.druid.msq.indexing.report.MSQStatusReport;
+import org.apache.druid.msq.indexing.report.MSQTaskReportPayload;
+import org.apache.druid.msq.sql.MSQTaskQueryMaker;
+import org.apache.druid.segment.column.ColumnType;
+import org.apache.druid.server.QueryResponse;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
+import org.apache.druid.sql.calcite.rel.DruidQuery;
+import org.apache.druid.sql.calcite.run.QueryMaker;
+import org.apache.druid.sql.calcite.run.SqlResults;
+
+import javax.annotation.Nullable;
+import java.io.ByteArrayOutputStream;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.NoSuchElementException;
+import java.util.Optional;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.stream.Collectors;
+
+/**
+ * SQL {@link QueryMaker}. Executes queries in two ways, depending on whether the user asked for a full report.
+ *
+ * When including a full report, the controller runs in {@link #controllerExecutor} via the method
+ * {@link #runWithReport(ControllerHolder)}. The entire response is buffered in memory, up to
+ * {@link DartControllerConfig#getMaxQueryReportSize()}, and returned as a single row.
+ *
+ * When not including a full report, the controller also runs in {@link #controllerExecutor} and results are
+ * streamed back to the user through {@link ResultIterator}. There is no limit to the size of the returned results.
+ */
+public class DartQueryMaker implements QueryMaker
+{
+  private static final Logger log = new Logger(DartQueryMaker.class);
+
+  private final List<Entry<Integer, String>> fieldMapping;
+  private final DartControllerContextFactory controllerContextFactory;
+  private final PlannerContext plannerContext;
+
+  /**
+   * Controller registry, used to register and remove controllers as they start and finish.
+   */
+  private final DartControllerRegistry controllerRegistry;
+
+  /**
+   * Controller config.
+   */
+  private final DartControllerConfig controllerConfig;
+
+  /**
+   * Executor that runs controllers for both run paths. The number of threads is equal to
+   * {@link DartControllerConfig#getConcurrentQueries()}, which limits the number of concurrent controllers.
+   */
+  private final ExecutorService controllerExecutor;
+
+  public DartQueryMaker(
+      List<Entry<Integer, String>> fieldMapping,
+      DartControllerContextFactory controllerContextFactory,
+      PlannerContext plannerContext,
+      DartControllerRegistry controllerRegistry,
+      DartControllerConfig controllerConfig,
+      ExecutorService controllerExecutor
+  )
+  {
+    this.fieldMapping = fieldMapping;
+    this.controllerContextFactory = controllerContextFactory;
+    this.plannerContext = plannerContext;
+    this.controllerRegistry = controllerRegistry;
+    this.controllerConfig = controllerConfig;
+    this.controllerExecutor = controllerExecutor;
+  }
+
+  @Override
+  public QueryResponse<Object[]> runQuery(DruidQuery druidQuery)
+  {
+    final MSQSpec querySpec = MSQTaskQueryMaker.makeQuerySpec(
+        null,
+        druidQuery,
+        fieldMapping,
+        plannerContext,
+        null // Only used for DML, which this isn't
+    );
+    final List<Pair<SqlTypeName, ColumnType>> types =
+        MSQTaskQueryMaker.getTypes(druidQuery, fieldMapping, plannerContext);
+
+    final String dartQueryId = druidQuery.getQuery().context().getString(DartSqlEngine.CTX_DART_QUERY_ID);
+    final ControllerContext controllerContext = controllerContextFactory.newContext(dartQueryId);
+    final ControllerImpl controller = new ControllerImpl(
+        dartQueryId,
+        querySpec,
+        new ResultsContext(
+            types.stream().map(p -> p.lhs).collect(Collectors.toList()),
+            SqlResults.Context.fromPlannerContext(plannerContext)
+        ),
+        controllerContext
+    );
+
+    final ControllerHolder controllerHolder = new ControllerHolder(
+        controller,
+        controllerContext,
+        plannerContext.getSqlQueryId(),
+        plannerContext.getSql(),
+        plannerContext.getAuthenticationResult(),
+        DateTimes.nowUtc()
+    );
+
+    final boolean fullReport = druidQuery.getQuery().context().getBoolean(
+        DartSqlEngine.CTX_FULL_REPORT,
+        DartSqlEngine.CTX_FULL_REPORT_DEFAULT
+    );
+
+    // Register controller before submitting anything to controllerExecutor, so it shows up in
+    // "active controllers" lists.
+    controllerRegistry.register(controllerHolder);
+
+    try {
+      // runWithReport, runWithoutReport are responsible for calling controllerRegistry.deregister(controllerHolder)
+      // when their work is done.
+      final Sequence<Object[]> results =
+          fullReport ? runWithReport(controllerHolder) : runWithoutReport(controllerHolder);
+      return QueryResponse.withEmptyContext(results);
+    }
+    catch (Throwable e) {
+      // Error while calling runWithReport or runWithoutReport. Deregister controller immediately.
+      controllerRegistry.deregister(controllerHolder);
+      throw e;
+    }
+  }
+
+  /**
+   * Run a query and return the full report, buffered in memory up to
+   * {@link DartControllerConfig#getMaxQueryReportSize()}.
+   *
+   * Arranges for {@link DartControllerRegistry#deregister(ControllerHolder)} to be called upon completion (either
+   * success or failure).
+   */
+  private Sequence<Object[]> runWithReport(final ControllerHolder controllerHolder)
+  {
+    // Run in controllerExecutor. Control doesn't really *need* to be moved to another thread, but we have to
+    // use the controllerExecutor anyway, to ensure we respect the concurrentQueries configuration.
+    final Future<Map<String, Object>> reportFuture = controllerExecutor.submit(() -> {
+      final String threadName = Thread.currentThread().getName();
+
+      try {
+        Thread.currentThread().setName(nameThread(plannerContext));
+
+        final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        final TaskReportQueryListener queryListener = new TaskReportQueryListener(
+            TaskReportMSQDestination.instance(),
+            () -> new LimitedOutputStream(
+                baos,
+                controllerConfig.getMaxQueryReportSize(),
+                limit -> StringUtils.format(
+                    "maxQueryReportSize[%,d] exceeded. "
+                    + "Try limiting the result set for your query, or run it with %s[false]",
+                    limit,
+                    DartSqlEngine.CTX_FULL_REPORT
+                )
+            ),
+            plannerContext.getJsonMapper(),
+            controllerHolder.getController().queryId(),
+            Collections.emptyMap()
+        );
+
+        if (controllerHolder.run(queryListener)) {
+          return plannerContext.getJsonMapper()
+                               .readValue(baos.toByteArray(), JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT);
+        } else {
+          // Controller was canceled before it ran.
+          throw MSQErrorReport
+              .fromFault(controllerHolder.getController().queryId(), null, null, CanceledFault.INSTANCE)
+              .toDruidException();
+        }
+      }
+      finally {
+        controllerRegistry.deregister(controllerHolder);
+        Thread.currentThread().setName(threadName);
+      }
+    });
+
+    // Return a sequence that reads one row (the report) from reportFuture.
+    return new BaseSequence<>(
+        new BaseSequence.IteratorMaker<Object[], Iterator<Object[]>>()
+        {
+          @Override
+          public Iterator<Object[]> make()
+          {
+            try {
+              return Iterators.singletonIterator(new Object[]{reportFuture.get()});
+            }
+            catch (InterruptedException e) {
+              // Preserve interrupt status for callers, matching ResultIterator's handling.
+              Thread.currentThread().interrupt();
+              throw new RuntimeException(e);
+            }
+            catch (ExecutionException e) {
+              // Unwrap ExecutionExceptions, so errors such as DruidException are serialized properly.
+              Throwables.throwIfUnchecked(e.getCause());
+              throw new RuntimeException(e.getCause());
+            }
+          }
+
+          @Override
+          public void cleanup(Iterator<Object[]> iterFromMake)
+          {
+            // Nothing to do.
+          }
+        }
+    );
+  }
+
+  /**
+   * Run a query and return the results only, streamed back using {@link ResultIteratorMaker}.
+   *
+   * Arranges for {@link DartControllerRegistry#deregister(ControllerHolder)} to be called upon completion (either
+   * success or failure).
+   */
+  private Sequence<Object[]> runWithoutReport(final ControllerHolder controllerHolder)
+  {
+    return new BaseSequence<>(new ResultIteratorMaker(controllerHolder));
+  }
+
+  /**
+   * Generate a name for a thread in {@link #controllerExecutor}.
+   */
+  private String nameThread(final PlannerContext plannerContext)
+  {
+    return StringUtils.format(
+        "%s-sqlQueryId[%s]-queryId[%s]",
+        Thread.currentThread().getName(),
+        plannerContext.getSqlQueryId(),
+        plannerContext.queryContext().get(DartSqlEngine.CTX_DART_QUERY_ID)
+    );
+  }
+
+  /**
+   * Helper for {@link #runWithoutReport(ControllerHolder)}.
+   */
+  class ResultIteratorMaker implements BaseSequence.IteratorMaker<Object[], ResultIterator>
+  {
+    private final ControllerHolder controllerHolder;
+    private final ResultIterator resultIterator = new ResultIterator();
+    private boolean made;
+
+    public ResultIteratorMaker(ControllerHolder holder)
+    {
+      this.controllerHolder = holder;
+      submitController();
+    }
+
+    /**
+     * Called from the constructor. Submits the controller to {@link #controllerExecutor}; the submitted task
+     * removes the controller from the registry when it finishes, whether or not the run succeeded.
+     */
+    private void submitController()
+    {
+      controllerExecutor.submit(() -> {
+        final Controller controller = controllerHolder.getController();
+        final String threadName = Thread.currentThread().getName();
+
+        try {
+          Thread.currentThread().setName(nameThread(plannerContext));
+
+          if (!controllerHolder.run(resultIterator)) {
+            // Controller was canceled before it ran. Push a cancellation error to the resultIterator, so the sequence
+            // returned by "runWithoutReport" can resolve.
+            resultIterator.pushError(
+                MSQErrorReport.fromFault(controller.queryId(), null, null, CanceledFault.INSTANCE)
+                              .toDruidException()
+            );
+          }
+        }
+        catch (Exception e) {
+          log.warn(
+              e,
+              "Controller failed for sqlQueryId[%s], queryId[%s]",
+              plannerContext.getSqlQueryId(),
+              controller.queryId()
+          );
+        }
+        finally {
+          controllerRegistry.deregister(controllerHolder);
+          Thread.currentThread().setName(threadName);
+        }
+      });
+    }
+
+    @Override
+    public ResultIterator make()
+    {
+      if (made) {
+        throw new ISE("Cannot call make() more than once");
+      }
+
+      made = true;
+      return resultIterator;
+    }
+
+    @Override
+    public void cleanup(final ResultIterator iterFromMake)
+    {
+      if (!iterFromMake.complete) {
+        controllerHolder.cancel();
+      }
+    }
+  }
+
+  /**
+   * Helper for {@link ResultIteratorMaker}, which is in turn a helper for {@link #runWithoutReport(ControllerHolder)}.
+   */
+  static class ResultIterator implements Iterator<Object[]>, QueryListener
+  {
+    /**
+     * Number of rows to buffer from {@link #onResultRow(Object[])}.
+     */
+    private static final int BUFFER_SIZE = 128;
+
+    /**
+     * A value entry holding null signifies results are complete; an error entry carries the failure to throw.
+     */
+    private final BlockingQueue<Either<Throwable, Object[]>> rowBuffer = new ArrayBlockingQueue<>(BUFFER_SIZE);
+
+    /**
+     * Only accessed by {@link Iterator} methods, so no need to be thread-safe.
+     */
+    @Nullable
+    private Either<Throwable, Object[]> current;
+
+    private volatile boolean complete;
+
+    @Override
+    public boolean hasNext()
+    {
+      return populateAndReturnCurrent().isPresent();
+    }
+
+    @Override
+    public Object[] next()
+    {
+      final Object[] retVal = populateAndReturnCurrent().orElseThrow(NoSuchElementException::new);
+      current = null;
+      return retVal;
+    }
+
+    private Optional<Object[]> populateAndReturnCurrent()
+    {
+      if (current == null) {
+        try {
+          current = rowBuffer.take();
+        }
+        catch (InterruptedException e) {
+          Thread.currentThread().interrupt();
+          throw new RuntimeException(e);
+        }
+      }
+
+      if (current.isValue()) {
+        return Optional.ofNullable(current.valueOrThrow());
+      } else {
+        // Don't use valueOrThrow to throw errors; here we *don't* want the wrapping in RuntimeException
+        // that Either.valueOrThrow does. We want the original DruidException to be propagated to the user, if
+        // there is one.
+        final Throwable e = current.error();
+        Throwables.throwIfUnchecked(e);
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public boolean readResults()
+    {
+      return !complete;
+    }
+
+    @Override
+    public void onResultsStart(
+        final List<MSQResultsReport.ColumnAndType> signature,
+        @Nullable final List<SqlTypeName> sqlTypeNames
+    )
+    {
+      // Nothing to do.
+    }
+
+    @Override
+    public boolean onResultRow(Object[] row)
+    {
+      try {
+        rowBuffer.put(Either.value(row));
+        return !complete;
+      }
+      catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public void onResultsComplete()
+    {
+      // Nothing to do.
+    }
+
+    @Override
+    public void onQueryComplete(MSQTaskReportPayload report)
+    {
+      try {
+        complete = true;
+
+        final MSQStatusReport statusReport = report.getStatus();
+
+        if (statusReport.getStatus().isSuccess()) {
+          rowBuffer.put(Either.value(null));
+        } else {
+          pushError(statusReport.getErrorReport().toDruidException());
+        }
+      }
+      catch (InterruptedException e) {
+        // Can't fix this by pushing an error, because the rowBuffer isn't accepting new entries.
+        // Give up, allow controllerHolder.run() to fail.
+        Thread.currentThread().interrupt();
+        throw new RuntimeException(e);
+      }
+    }
+
+    public void pushError(final Throwable e) throws InterruptedException
+    {
+      rowBuffer.put(Either.error(e));
+    }
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClient.java
new file mode 100644
index 000000000000..447da229d05e
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClient.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.sql;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.msq.dart.controller.http.DartSqlResource;
+import org.apache.druid.msq.dart.controller.http.GetQueriesResponse;
+
+import javax.servlet.http.HttpServletRequest;
+
+/**
+ * Client for the {@link DartSqlResource} resource.
+ */
+public interface DartSqlClient
+{
+  /**
+   * Get information about all currently-running queries on this server.
+   *
+   * @param selfOnly true if only queries from this server should be returned; false if queries from all servers
+   *                 should be returned
+   * @return future of the server's {@link GetQueriesResponse}
+   * @see DartSqlResource#doGetRunningQueries(String, HttpServletRequest) the server side
+   */
+  ListenableFuture<GetQueriesResponse> getRunningQueries(boolean selfOnly);
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactory.java
new file mode 100644
index 000000000000..879cabe6945f
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactory.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.
See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.sql;
+
+import org.apache.druid.server.DruidNode;
+
+/**
+ * Creates a {@link DartSqlClient} for a target Broker; {@link #makeClient} takes that Broker's {@link DruidNode}.
+ */
+public interface DartSqlClientFactory
+{
+  DartSqlClient makeClient(DruidNode node);
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactoryImpl.java
new file mode 100644
index 000000000000..c2355a43e31a
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientFactoryImpl.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.sql;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.inject.Inject;
+import org.apache.druid.guice.annotations.EscalatedGlobal;
+import org.apache.druid.guice.annotations.Json;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.msq.dart.controller.http.DartSqlResource;
+import org.apache.druid.rpc.FixedServiceLocator;
+import org.apache.druid.rpc.ServiceClient;
+import org.apache.druid.rpc.ServiceClientFactory;
+import org.apache.druid.rpc.ServiceLocation;
+import org.apache.druid.rpc.StandardRetryPolicy;
+import org.apache.druid.server.DruidNode;
+
+/**
+ * Production implementation of {@link DartSqlClientFactory}, backed by a {@link ServiceClientFactory}.
+ */
+public class DartSqlClientFactoryImpl implements DartSqlClientFactory
+{
+  private final ServiceClientFactory clientFactory;
+  private final ObjectMapper jsonMapper;
+
+  @Inject
+  public DartSqlClientFactoryImpl(
+      @EscalatedGlobal final ServiceClientFactory clientFactory,
+      @Json final ObjectMapper jsonMapper
+  )
+  {
+    this.clientFactory = clientFactory;
+    this.jsonMapper = jsonMapper;
+  }
+
+  @Override
+  public DartSqlClient makeClient(DruidNode node)
+  {
+    final String serviceName = StringUtils.format("%s[dart-sql]", node.getHostAndPortToUse());
+    // The client always targets the given node, under the DartSqlResource base path.
+    final FixedServiceLocator locator =
+        new FixedServiceLocator(ServiceLocation.fromDruidNode(node).withBasePath(DartSqlResource.PATH));
+    final ServiceClient client = clientFactory.makeClient(serviceName, locator, StandardRetryPolicy.noRetries());
+
+    return new DartSqlClientImpl(client, jsonMapper);
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImpl.java
new file mode 100644
index 000000000000..aebf7e4b90fa
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImpl.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache
Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.sql;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.common.guava.FutureUtils;
+import org.apache.druid.java.util.common.jackson.JacksonUtils;
+import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler;
+import org.apache.druid.msq.dart.controller.http.GetQueriesResponse;
+import org.apache.druid.rpc.RequestBuilder;
+import org.apache.druid.rpc.ServiceClient;
+import org.jboss.netty.handler.codec.http.HttpMethod;
+
+/**
+ * Production implementation of {@link DartSqlClient}: issues a GET via {@link ServiceClient} and parses the JSON body.
+ */
+public class DartSqlClientImpl implements DartSqlClient
+{
+  private final ServiceClient client;
+  private final ObjectMapper jsonMapper;
+
+  public DartSqlClientImpl(final ServiceClient client, final ObjectMapper jsonMapper)
+  {
+    this.client = client;
+    this.jsonMapper = jsonMapper;
+  }
+
+  @Override
+  public ListenableFuture<GetQueriesResponse> getRunningQueries(final boolean selfOnly)
+  {
+    return FutureUtils.transform(
+        client.asyncRequest(
+            new RequestBuilder(HttpMethod.GET, selfOnly ? "/?selfOnly" : "/"),
+            new BytesFullResponseHandler()
+        ),
+        holder -> JacksonUtils.readValue(jsonMapper, holder.getContent(), GetQueriesResponse.class)
+    );
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClients.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClients.java
new file mode 100644
index 000000000000..733f69ee4bf9
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlClients.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.google.common.collect.ImmutableList; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import com.google.inject.Inject; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.java.util.common.lifecycle.LifecycleStart; +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.server.DruidNode; + +import javax.servlet.http.HttpServletRequest; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Keeps {@link DartSqlClient} for all servers except ourselves. Currently the purpose of this is to power + * the "get all queries" API at {@link DartSqlResource#doGetRunningQueries(String, HttpServletRequest)}. 
+ */ +@ManageLifecycle +public class DartSqlClients implements DruidNodeDiscovery.Listener +{ + @GuardedBy("clients") + private final Map clients = new HashMap<>(); + private final DruidNode selfNode; + private final DruidNodeDiscoveryProvider discoveryProvider; + private final DartSqlClientFactory clientFactory; + + private volatile DruidNodeDiscovery discovery; + + @Inject + public DartSqlClients( + @Self DruidNode selfNode, + DruidNodeDiscoveryProvider discoveryProvider, + DartSqlClientFactory clientFactory + ) + { + this.selfNode = selfNode; + this.discoveryProvider = discoveryProvider; + this.clientFactory = clientFactory; + } + + @LifecycleStart + public void start() + { + discovery = discoveryProvider.getForNodeRole(NodeRole.BROKER); + discovery.registerListener(this); + } + + public List getAllClients() + { + synchronized (clients) { + return ImmutableList.copyOf(clients.values()); + } + } + + @Override + public void nodesAdded(final Collection nodes) + { + synchronized (clients) { + for (final DiscoveryDruidNode node : nodes) { + final DruidNode druidNode = node.getDruidNode(); + if (!selfNode.equals(druidNode)) { + clients.computeIfAbsent(druidNode, clientFactory::makeClient); + } + } + } + } + + @Override + public void nodesRemoved(final Collection nodes) + { + synchronized (clients) { + for (final DiscoveryDruidNode node : nodes) { + clients.remove(node.getDruidNode()); + } + } + } + + @LifecycleStop + public void stop() + { + if (discovery != null) { + discovery.removeListener(this); + discovery = null; + } + + synchronized (clients) { + clients.clear(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlEngine.java new file mode 100644 index 000000000000..28587e0e791a --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/controller/sql/DartSqlEngine.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller.sql; + +import com.google.common.collect.ImmutableList; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.msq.dart.controller.DartControllerContextFactory; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.guice.DartControllerConfig; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.sql.MSQTaskSqlEngine; +import org.apache.druid.query.BaseQuery; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; +import org.apache.druid.sql.SqlLifecycleManager; +import org.apache.druid.sql.calcite.planner.Calcites; +import 
org.apache.druid.sql.calcite.planner.PlannerContext; +import org.apache.druid.sql.calcite.run.EngineFeature; +import org.apache.druid.sql.calcite.run.QueryMaker; +import org.apache.druid.sql.calcite.run.SqlEngine; +import org.apache.druid.sql.calcite.run.SqlEngines; +import org.apache.druid.sql.destination.IngestDestination; + +import java.util.Map; +import java.util.concurrent.ExecutorService; + +public class DartSqlEngine implements SqlEngine +{ + private static final String NAME = "msq-dart"; + + /** + * Dart queryId must be globally unique, so we cannot use the user-provided {@link QueryContexts#CTX_SQL_QUERY_ID} + * or {@link BaseQuery#QUERY_ID}. Instead we generate a UUID in {@link DartSqlResource#doPost}, overriding whatever + * the user may have provided. This becomes the {@link Controller#queryId()}. + * + * The user-provided {@link QueryContexts#CTX_SQL_QUERY_ID} is still registered with the {@link SqlLifecycleManager} + * for purposes of query cancellation. + * + * The user-provided {@link BaseQuery#QUERY_ID} is ignored. 
+ */ + public static final String CTX_DART_QUERY_ID = "dartQueryId"; + public static final String CTX_FULL_REPORT = "fullReport"; + public static final boolean CTX_FULL_REPORT_DEFAULT = false; + + private final DartControllerContextFactory controllerContextFactory; + private final DartControllerRegistry controllerRegistry; + private final DartControllerConfig controllerConfig; + private final ExecutorService controllerExecutor; + + public DartSqlEngine( + DartControllerContextFactory controllerContextFactory, + DartControllerRegistry controllerRegistry, + DartControllerConfig controllerConfig, + ExecutorService controllerExecutor + ) + { + this.controllerContextFactory = controllerContextFactory; + this.controllerRegistry = controllerRegistry; + this.controllerConfig = controllerConfig; + this.controllerExecutor = controllerExecutor; + } + + @Override + public String name() + { + return NAME; + } + + @Override + public boolean featureAvailable(EngineFeature feature) + { + switch (feature) { + case CAN_SELECT: + case SCAN_ORDER_BY_NON_TIME: + case SCAN_NEEDS_SIGNATURE: + case WINDOW_FUNCTIONS: + case WINDOW_LEAF_OPERATOR: + case UNNEST: + return true; + + case CAN_INSERT: + case CAN_REPLACE: + case READ_EXTERNAL_DATA: + case ALLOW_BINDABLE_PLAN: + case ALLOW_BROADCAST_RIGHTY_JOIN: + case ALLOW_TOP_LEVEL_UNION_ALL: + case TIMESERIES_QUERY: + case TOPN_QUERY: + case TIME_BOUNDARY_QUERY: + case GROUPING_SETS: + case GROUPBY_IMPLICITLY_SORTS: + return false; + + default: + throw new IAE("Unrecognized feature: %s", feature); + } + } + + @Override + public void validateContext(Map queryContext) + { + SqlEngines.validateNoSpecialContextKeys(queryContext, MSQTaskSqlEngine.SYSTEM_CONTEXT_PARAMETERS); + } + + @Override + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) + { + if (QueryContext.of(queryContext).getBoolean(CTX_FULL_REPORT, CTX_FULL_REPORT_DEFAULT)) { + return 
typeFactory.createStructType( + ImmutableList.of( + Calcites.createSqlType(typeFactory, SqlTypeName.VARCHAR) + ), + ImmutableList.of(CTX_FULL_REPORT) + ); + } else { + return validatedRowType; + } + } + + @Override + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) + { + // Defensive, because we expect this method will not be called without the CAN_INSERT and CAN_REPLACE features. + throw DruidException.defensive("Cannot execute DML commands with engine[%s]", name()); + } + + @Override + public QueryMaker buildQueryMakerForSelect(RelRoot relRoot, PlannerContext plannerContext) + { + return new DartQueryMaker( + relRoot.fields, + controllerContextFactory, + plannerContext, + controllerRegistry, + controllerConfig, + controllerExecutor + ); + } + + @Override + public QueryMaker buildQueryMakerForInsert( + IngestDestination destination, + RelRoot relRoot, + PlannerContext plannerContext + ) + { + // Defensive, because we expect this method will not be called without the CAN_INSERT and CAN_REPLACE features. + throw DruidException.defensive("Cannot execute DML commands with engine[%s]", name()); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerConfig.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerConfig.java new file mode 100644 index 000000000000..25094f44a79a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerConfig.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import com.fasterxml.jackson.annotation.JsonProperty; + +/** + * Runtime configuration for controllers (which run on Brokers). + */ +public class DartControllerConfig +{ + @JsonProperty("concurrentQueries") + private int concurrentQueries = 1; + + @JsonProperty("maxQueryReportSize") + private int maxQueryReportSize = 100_000_000; + + public int getConcurrentQueries() + { + return concurrentQueries; + } + + public int getMaxQueryReportSize() + { + return maxQueryReportSize; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerMemoryManagementModule.java new file mode 100644 index 000000000000..95f110ec88be --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerMemoryManagementModule.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.utils.JvmUtils; + +/** + * Memory management module for Brokers. + */ +@LoadScope(roles = {NodeRole.BROKER_JSON_NAME}) +public class DartControllerMemoryManagementModule implements DruidModule +{ + /** + * Allocate up to 15% of memory for the MSQ framework. This accounts for additional overhead due to native queries, + * the segment timeline, and lookups (which aren't accounted for by our {@link MemoryIntrospector}). + */ + public static final double USABLE_MEMORY_FRACTION = 0.15; + + @Override + public void configure(Binder binder) + { + // Nothing to do. 
+ } + + @Provides + public MemoryIntrospector createMemoryIntrospector( + final DruidProcessingConfig processingConfig, + final DartControllerConfig controllerConfig + ) + { + return new MemoryIntrospectorImpl( + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + USABLE_MEMORY_FRACTION, + controllerConfig.getConcurrentQueries(), + processingConfig.getNumThreads(), + null + ); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerModule.java new file mode 100644 index 000000000000..8a4b73bc9b0f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartControllerModule.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Inject; +import com.google.inject.Module; +import com.google.inject.Provides; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.Jerseys; +import org.apache.druid.guice.JsonConfigProvider; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.LifecycleModule; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.DartResourcePermissionMapper; +import org.apache.druid.msq.dart.controller.ControllerMessageListener; +import org.apache.druid.msq.dart.controller.DartControllerContextFactory; +import org.apache.druid.msq.dart.controller.DartControllerContextFactoryImpl; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.DartMessageRelayFactoryImpl; +import org.apache.druid.msq.dart.controller.DartMessageRelays; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.controller.sql.DartSqlClientFactory; +import org.apache.druid.msq.dart.controller.sql.DartSqlClientFactoryImpl; +import org.apache.druid.msq.dart.controller.sql.DartSqlClients; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.query.DefaultQueryConfig; +import org.apache.druid.sql.SqlStatementFactory; +import org.apache.druid.sql.SqlToolbox; + +import java.util.Properties; + +/** + * Primary module for Brokers. Checks {@link DartModules#isDartEnabled(Properties)} before installing itself. 
+ */ +@LoadScope(roles = NodeRole.BROKER_JSON_NAME) +public class DartControllerModule implements DruidModule +{ + @Inject + private Properties properties; + + @Override + public void configure(Binder binder) + { + if (DartModules.isDartEnabled(properties)) { + binder.install(new ActualModule()); + } + } + + public static class ActualModule implements Module + { + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bind(binder, DartModules.DART_PROPERTY_BASE + ".controller", DartControllerConfig.class); + JsonConfigProvider.bind(binder, DartModules.DART_PROPERTY_BASE + ".query", DefaultQueryConfig.class, Dart.class); + + Jerseys.addResource(binder, DartSqlResource.class); + + LifecycleModule.register(binder, DartSqlClients.class); + LifecycleModule.register(binder, DartMessageRelays.class); + + binder.bind(ControllerMessageListener.class).in(LazySingleton.class); + binder.bind(DartControllerRegistry.class).in(LazySingleton.class); + binder.bind(DartMessageRelayFactoryImpl.class).in(LazySingleton.class); + binder.bind(DartControllerContextFactory.class) + .to(DartControllerContextFactoryImpl.class) + .in(LazySingleton.class); + binder.bind(DartSqlClientFactory.class) + .to(DartSqlClientFactoryImpl.class) + .in(LazySingleton.class); + binder.bind(ResourcePermissionMapper.class) + .annotatedWith(Dart.class) + .to(DartResourcePermissionMapper.class); + } + + @Provides + @Dart + @LazySingleton + public SqlStatementFactory makeSqlStatementFactory(final DartSqlEngine engine, final SqlToolbox toolbox) + { + return new SqlStatementFactory(toolbox.withEngine(engine)); + } + + @Provides + @ManageLifecycle + public DartMessageRelays makeMessageRelays( + final DruidNodeDiscoveryProvider discoveryProvider, + final DartMessageRelayFactoryImpl messageRelayFactory + ) + { + return new DartMessageRelays(discoveryProvider, messageRelayFactory); + } + + @Provides + @LazySingleton + public DartSqlEngine makeSqlEngine( + DartControllerContextFactory 
controllerContextFactory, + DartControllerRegistry controllerRegistry, + DartControllerConfig controllerConfig + ) + { + return new DartSqlEngine( + controllerContextFactory, + controllerRegistry, + controllerConfig, + Execs.multiThreaded(controllerConfig.getConcurrentQueries(), "dart-controller-%s") + ); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartModules.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartModules.java new file mode 100644 index 000000000000..a8e1a1b65e69 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartModules.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import java.util.Properties; + +/** + * Common utilities for Dart Guice modules. 
+ */ +public class DartModules +{ + public static final String DART_PROPERTY_BASE = "druid.msq.dart"; + public static final String DART_ENABLED_PROPERTY = DART_PROPERTY_BASE + ".enabled"; + public static final String DART_ENABLED_DEFAULT = String.valueOf(false); + + public static boolean isDartEnabled(final Properties properties) + { + return Boolean.parseBoolean(properties.getProperty(DART_ENABLED_PROPERTY, DART_ENABLED_DEFAULT)); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerConfig.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerConfig.java new file mode 100644 index 000000000000..f7322a1af92c --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerConfig.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.msq.exec.MemoryIntrospector; + +/** + * Runtime configuration for workers (which run on Historicals). 
+ */ +public class DartWorkerConfig +{ + /** + * By default, allocate up to 35% of memory for the MSQ framework. This accounts for additional overhead due to + * native queries, and lookups (which aren't accounted for by the Dart {@link MemoryIntrospector}). + */ + private static final double DEFAULT_HEAP_FRACTION = 0.35; + + public static final int AUTO = -1; + + @JsonProperty("concurrentQueries") + private int concurrentQueries = AUTO; + + @JsonProperty("heapFraction") + private double heapFraction = DEFAULT_HEAP_FRACTION; + + public int getConcurrentQueries() + { + return concurrentQueries; + } + + public double getHeapFraction() + { + return heapFraction; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerMemoryManagementModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerMemoryManagementModule.java new file mode 100644 index 000000000000..9f51a65152a1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerMemoryManagementModule.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Provides; +import org.apache.druid.collections.BlockingPool; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.guice.annotations.Merging; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.worker.DartProcessingBuffersProvider; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.MemoryIntrospectorImpl; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.utils.JvmUtils; + +import java.nio.ByteBuffer; + +/** + * Memory management module for Historicals. + */ +@LoadScope(roles = {NodeRole.HISTORICAL_JSON_NAME}) +public class DartWorkerMemoryManagementModule implements DruidModule +{ + @Override + public void configure(Binder binder) + { + // Nothing to do. 
+ } + + @Provides + public MemoryIntrospector createMemoryIntrospector( + final DartWorkerConfig workerConfig, + final DruidProcessingConfig druidProcessingConfig + ) + { + return new MemoryIntrospectorImpl( + JvmUtils.getRuntimeInfo().getMaxHeapSizeBytes(), + workerConfig.getHeapFraction(), + computeConcurrentQueries(workerConfig, druidProcessingConfig), + druidProcessingConfig.getNumThreads(), + null + ); + } + + @Provides + @Dart + @LazySingleton + public ProcessingBuffersProvider createProcessingBuffersProvider( + @Merging final BlockingPool mergeBufferPool, + final DruidProcessingConfig processingConfig + ) + { + return new DartProcessingBuffersProvider(mergeBufferPool, processingConfig.getNumThreads()); + } + + private static int computeConcurrentQueries( + final DartWorkerConfig workerConfig, + final DruidProcessingConfig processingConfig + ) + { + if (workerConfig.getConcurrentQueries() == DartWorkerConfig.AUTO) { + return processingConfig.getNumMergeBuffers(); + } else if (workerConfig.getConcurrentQueries() < 0) { + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("concurrentQueries[%s] must be positive or -1", workerConfig.getConcurrentQueries()); + } else if (workerConfig.getConcurrentQueries() > processingConfig.getNumMergeBuffers()) { + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build( + "concurrentQueries[%s] must be less than numMergeBuffers[%s]", + workerConfig.getConcurrentQueries(), + processingConfig.getNumMergeBuffers() + ); + } else { + return workerConfig.getConcurrentQueries(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java new file mode 100644 index 000000000000..15bc0e652994 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/guice/DartWorkerModule.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.guice; + +import com.google.inject.Binder; +import com.google.inject.Inject; +import com.google.inject.Key; +import com.google.inject.Module; +import com.google.inject.Provides; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.guice.Jerseys; +import org.apache.druid.guice.JsonConfigProvider; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.LifecycleModule; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.guice.ManageLifecycleAnnouncements; +import org.apache.druid.guice.annotations.LoadScope; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.initialization.DruidModule; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.messages.server.MessageRelayMonitor; +import org.apache.druid.messages.server.MessageRelayResource; +import 
org.apache.druid.messages.server.Outbox; +import org.apache.druid.messages.server.OutboxImpl; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.DartResourcePermissionMapper; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.DartDataSegmentProvider; +import org.apache.druid.msq.dart.worker.DartWorkerFactory; +import org.apache.druid.msq.dart.worker.DartWorkerFactoryImpl; +import org.apache.druid.msq.dart.worker.DartWorkerRunner; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.security.AuthorizerMapper; + +import java.io.File; +import java.util.Properties; +import java.util.concurrent.ExecutorService; + +/** + * Primary module for workers. Checks {@link DartModules#isDartEnabled(Properties)} before installing itself. 
+ */ +@LoadScope(roles = NodeRole.HISTORICAL_JSON_NAME) +public class DartWorkerModule implements DruidModule +{ + @Inject + private Properties properties; + + @Override + public void configure(Binder binder) + { + if (DartModules.isDartEnabled(properties)) { + binder.install(new ActualModule()); + } + } + + public static class ActualModule implements Module + { + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bind(binder, DartModules.DART_PROPERTY_BASE + ".worker", DartWorkerConfig.class); + Jerseys.addResource(binder, DartWorkerResource.class); + LifecycleModule.register(binder, DartWorkerRunner.class); + LifecycleModule.registerKey(binder, Key.get(MessageRelayMonitor.class, Dart.class)); + + binder.bind(DartWorkerFactory.class) + .to(DartWorkerFactoryImpl.class) + .in(LazySingleton.class); + + binder.bind(DataSegmentProvider.class) + .annotatedWith(Dart.class) + .to(DartDataSegmentProvider.class) + .in(LazySingleton.class); + + binder.bind(ResourcePermissionMapper.class) + .annotatedWith(Dart.class) + .to(DartResourcePermissionMapper.class); + } + + @Provides + @ManageLifecycle + public DartWorkerRunner createWorkerRunner( + @Self final DruidNode selfNode, + final DartWorkerFactory workerFactory, + final DruidNodeDiscoveryProvider discoveryProvider, + final DruidProcessingConfig processingConfig, + @Dart final ResourcePermissionMapper permissionMapper, + final MemoryIntrospector memoryIntrospector, + final AuthorizerMapper authorizerMapper + ) + { + final ExecutorService exec = Execs.multiThreaded(memoryIntrospector.numTasksInJvm(), "dart–worker-%s"); + final File baseTempDir = + new File(processingConfig.getTmpDir(), StringUtils.format("dart_%s", selfNode.getPortToUse())); + return new DartWorkerRunner( + workerFactory, + exec, + discoveryProvider, + permissionMapper, + authorizerMapper, + baseTempDir + ); + } + + @Provides + @Dart + public MessageRelayMonitor createMessageRelayMonitor( + final DruidNodeDiscoveryProvider 
discoveryProvider, + final Outbox outbox + ) + { + return new MessageRelayMonitor(discoveryProvider, outbox, NodeRole.BROKER); + } + + /** + * Create an {@link Outbox}. + * + * This is {@link ManageLifecycleAnnouncements} scoped so {@link OutboxImpl#stop()} gets called before attempting + * to shut down the Jetty server. If this doesn't happen, then server shutdown is delayed by however long it takes + * any currently-in-flight {@link MessageRelayResource#httpGetMessagesFromOutbox} to resolve. + */ + @Provides + @ManageLifecycleAnnouncements + public Outbox createOutbox() + { + return new OutboxImpl<>(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartControllerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartControllerClient.java new file mode 100644 index 000000000000..23d83d005497 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartControllerClient.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.common.guava.FutureBox; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.controller.messages.DoneReadingInput; +import org.apache.druid.msq.dart.controller.messages.PartialKeyStatistics; +import org.apache.druid.msq.dart.controller.messages.ResultsComplete; +import org.apache.druid.msq.dart.controller.messages.WorkerError; +import org.apache.druid.msq.dart.controller.messages.WorkerWarning; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation; + +import javax.annotation.Nullable; +import java.util.List; + +/** + * Implementation of {@link ControllerClient} that uses an {@link Outbox} to send {@link ControllerMessage} + * to a controller. + */ +public class DartControllerClient implements ControllerClient +{ + private final Outbox outbox; + private final String queryId; + private final String controllerHost; + + /** + * Currently-outstanding futures. These are tracked so they can be canceled in {@link #close()}. 
+ */ + private final FutureBox futureBox = new FutureBox(); + + public DartControllerClient( + final Outbox outbox, + final String queryId, + final String controllerHost + ) + { + this.outbox = outbox; + this.queryId = queryId; + this.controllerHost = controllerHost; + } + + @Override + public void postPartialKeyStatistics( + final StageId stageId, + final int workerNumber, + final PartialKeyStatisticsInformation partialKeyStatisticsInformation + ) + { + validateStage(stageId); + sendMessage(new PartialKeyStatistics(stageId, workerNumber, partialKeyStatisticsInformation)); + } + + @Override + public void postDoneReadingInput(StageId stageId, int workerNumber) + { + validateStage(stageId); + sendMessage(new DoneReadingInput(stageId, workerNumber)); + } + + @Override + public void postResultsComplete(StageId stageId, int workerNumber, @Nullable Object resultObject) + { + validateStage(stageId); + sendMessage(new ResultsComplete(stageId, workerNumber, resultObject)); + } + + @Override + public void postWorkerError(MSQErrorReport errorWrapper) + { + sendMessage(new WorkerError(queryId, errorWrapper)); + } + + @Override + public void postWorkerWarning(List errorWrappers) + { + sendMessage(new WorkerWarning(queryId, errorWrappers)); + } + + @Override + public void postCounters(String workerId, CounterSnapshotsTree snapshotsTree) + { + // Do nothing. Live counters are not sent to the controller in this mode. + } + + @Override + public List getWorkerIds() + { + // Workers are set in advance through the WorkOrder, so this method isn't used. + throw new UnsupportedOperationException(); + } + + @Override + public void close() + { + // Cancel any pending futures. + futureBox.close(); + } + + private void sendMessage(final ControllerMessage message) + { + FutureUtils.getUnchecked(futureBox.register(outbox.sendMessage(controllerHost, message)), true); + } + + /** + * Validate that a {@link StageId} has the expected query ID. 
+ */ + private void validateStage(final StageId stageId) + { + if (!stageId.getQueryId().equals(queryId)) { + throw DruidException.defensive( + "Expected queryId[%s] but got queryId[%s], stageNumber[%s]", + queryId, + stageId.getQueryId(), + stageId.getStageNumber() + ); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataSegmentProvider.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataSegmentProvider.java new file mode 100644 index 000000000000..0e8a38af90a3 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartDataSegmentProvider.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.google.inject.Inject; +import org.apache.druid.collections.ReferenceCountingResourceHolder; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.msq.counters.ChannelCounters; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.segment.CompleteSegment; +import org.apache.druid.segment.PhysicalSegmentInspector; +import org.apache.druid.segment.ReferenceCountingSegment; +import org.apache.druid.server.SegmentManager; +import org.apache.druid.timeline.SegmentId; +import org.apache.druid.timeline.VersionedIntervalTimeline; +import org.apache.druid.timeline.partition.PartitionChunk; + +import java.io.Closeable; +import java.util.Optional; +import java.util.function.Supplier; + +/** + * Implementation of {@link DataSegmentProvider} that uses locally-cached segments from a {@link SegmentManager}. 
+ */ +public class DartDataSegmentProvider implements DataSegmentProvider +{ + private final SegmentManager segmentManager; + + @Inject + public DartDataSegmentProvider(SegmentManager segmentManager) + { + this.segmentManager = segmentManager; + } + + @Override + public Supplier> fetchSegment( + SegmentId segmentId, + ChannelCounters channelCounters, + boolean isReindex + ) + { + if (isReindex) { + throw DruidException.defensive("Got isReindex[%s], expected false", isReindex); + } + + return () -> { + final Optional> timeline = + segmentManager.getTimeline(new TableDataSource(segmentId.getDataSource()).getAnalysis()); + + if (!timeline.isPresent()) { + throw segmentNotFound(segmentId); + } + + final PartitionChunk chunk = + timeline.get().findChunk( + segmentId.getInterval(), + segmentId.getVersion(), + segmentId.getPartitionNum() + ); + + if (chunk == null) { + throw segmentNotFound(segmentId); + } + + final ReferenceCountingSegment segment = chunk.getObject(); + final Optional closeable = segment.acquireReferences(); + if (!closeable.isPresent()) { + // Segment has disappeared before we could acquire a reference to it. + throw segmentNotFound(segmentId); + } + + final Closer closer = Closer.create(); + closer.register(closeable.get()); + closer.register(() -> { + final PhysicalSegmentInspector inspector = segment.as(PhysicalSegmentInspector.class); + channelCounters.addFile(inspector != null ? inspector.getNumRows() : 0, 0); + }); + return new ReferenceCountingResourceHolder<>(new CompleteSegment(null, segment), closer); + }; + } + + /** + * Error to throw when a segment that was requested is not found. This can happen due to segment moves, etc. + */ + private static DruidException segmentNotFound(final SegmentId segmentId) + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("Segment[%s] not found on this server. 
Please retry your query.", segmentId); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java new file mode 100644 index 000000000000..ff7d9fdc4e9f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartFrameContext.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.WorkerContext; +import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; +import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.StageId; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.segment.IndexIO; +import org.apache.druid.segment.IndexMergerV9; +import org.apache.druid.segment.SegmentWrangler; +import org.apache.druid.segment.incremental.NoopRowIngestionMeters; +import org.apache.druid.segment.incremental.RowIngestionMeters; +import org.apache.druid.segment.loading.DataSegmentPusher; + +import javax.annotation.Nullable; +import java.io.File; + +/** + * Dart implementation of {@link FrameContext}. 
+ */ +public class DartFrameContext implements FrameContext +{ + private final StageId stageId; + private final SegmentWrangler segmentWrangler; + private final GroupingEngine groupingEngine; + private final DataSegmentProvider dataSegmentProvider; + private final WorkerContext workerContext; + @Nullable + private final ResourceHolder processingBuffers; + private final WorkerMemoryParameters memoryParameters; + private final WorkerStorageParameters storageParameters; + + public DartFrameContext( + final StageId stageId, + final WorkerContext workerContext, + final SegmentWrangler segmentWrangler, + final GroupingEngine groupingEngine, + final DataSegmentProvider dataSegmentProvider, + @Nullable ResourceHolder processingBuffers, + final WorkerMemoryParameters memoryParameters, + final WorkerStorageParameters storageParameters + ) + { + this.stageId = stageId; + this.segmentWrangler = segmentWrangler; + this.groupingEngine = groupingEngine; + this.dataSegmentProvider = dataSegmentProvider; + this.workerContext = workerContext; + this.processingBuffers = processingBuffers; + this.memoryParameters = memoryParameters; + this.storageParameters = storageParameters; + } + + @Override + public SegmentWrangler segmentWrangler() + { + return segmentWrangler; + } + + @Override + public GroupingEngine groupingEngine() + { + return groupingEngine; + } + + @Override + public RowIngestionMeters rowIngestionMeters() + { + return new NoopRowIngestionMeters(); + } + + @Override + public DataSegmentProvider dataSegmentProvider() + { + return dataSegmentProvider; + } + + @Override + public File tempDir() + { + return new File(workerContext.tempDir(), stageId.toString()); + } + + @Override + public ObjectMapper jsonMapper() + { + return workerContext.jsonMapper(); + } + + @Override + public IndexIO indexIO() + { + throw new UnsupportedOperationException(); + } + + @Override + public File persistDir() + { + return new File(tempDir(), "persist"); + } + + @Override + public 
DataSegmentPusher segmentPusher() + { + throw DruidException.defensive("Ingestion not implemented"); + } + + @Override + public IndexMergerV9 indexMerger() + { + throw DruidException.defensive("Ingestion not implemented"); + } + + @Override + public ProcessingBuffers processingBuffers() + { + if (processingBuffers != null) { + return processingBuffers.get(); + } else { + throw new ISE("No processing buffers"); + } + } + + @Override + public WorkerMemoryParameters memoryParameters() + { + return memoryParameters; + } + + @Override + public WorkerStorageParameters storageParameters() + { + return storageParameters; + } + + @Override + public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() + { + // We don't query data servers. This factory won't actually be used, because Dart doesn't allow segmentSource to be + // overridden; it always uses SegmentSource.NONE. (If it is called, some wires got crossed somewhere.) + return null; + } + + @Override + public void close() + { + if (processingBuffers != null) { + processingBuffers.close(); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java new file mode 100644 index 000000000000..e2a7b97c4c2a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartProcessingBuffersProvider.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.collections.BlockingPool; +import org.apache.druid.collections.QueueNonBlockingPool; +import org.apache.druid.collections.ReferenceCountingResourceHolder; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.error.DruidException; +import org.apache.druid.frame.processor.Bouncer; +import org.apache.druid.msq.exec.ProcessingBuffers; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.utils.CloseableUtils; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; + +/** + * Production implementation of {@link ProcessingBuffersProvider} that uses the merge buffer pool. Each call + * to {@link #acquire(int)} acquires one merge buffer and slices it up. 
+ */ +public class DartProcessingBuffersProvider implements ProcessingBuffersProvider +{ + private final BlockingPool mergeBufferPool; + private final int processingThreads; + + public DartProcessingBuffersProvider(BlockingPool mergeBufferPool, int processingThreads) + { + this.mergeBufferPool = mergeBufferPool; + this.processingThreads = processingThreads; + } + + @Override + public ResourceHolder acquire(final int poolSize) + { + if (poolSize == 0) { + return new ReferenceCountingResourceHolder<>(ProcessingBuffersSet.EMPTY, () -> {}); + } + + final List> batch = mergeBufferPool.takeBatch(1, 0); + if (batch.isEmpty()) { + throw DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("No merge buffers available, cannot execute query"); + } + + final ReferenceCountingResourceHolder bufferHolder = batch.get(0); + try { + final ByteBuffer buffer = bufferHolder.get().duplicate(); + final int sliceSize = buffer.capacity() / poolSize / processingThreads; + final List pool = new ArrayList<>(poolSize); + + for (int i = 0; i < poolSize; i++) { + final BlockingQueue queue = new ArrayBlockingQueue<>(processingThreads); + for (int j = 0; j < processingThreads; j++) { + final int sliceNum = i * processingThreads + j; + buffer.position(sliceSize * sliceNum).limit(sliceSize * (sliceNum + 1)); + queue.add(buffer.slice()); + } + final ProcessingBuffers buffers = new ProcessingBuffers( + new QueueNonBlockingPool<>(queue), + new Bouncer(processingThreads) + ); + pool.add(buffers); + } + + return new ReferenceCountingResourceHolder<>(new ProcessingBuffersSet(pool), bufferHolder); + } + catch (Throwable e) { + throw CloseableUtils.closeAndWrapInCatch(e, bufferHolder); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartQueryableSegment.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartQueryableSegment.java new file mode 100644 
index 000000000000..574601517b44 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartQueryableSegment.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.google.common.base.Preconditions; +import org.apache.druid.timeline.DataSegment; +import org.joda.time.Interval; + +import java.util.Objects; + +/** + * Represents a segment that is queryable at a specific worker number. 
+ */ +public class DartQueryableSegment +{ + private final DataSegment segment; + private final Interval interval; + private final int workerNumber; + + public DartQueryableSegment(final DataSegment segment, final Interval interval, final int workerNumber) + { + this.segment = Preconditions.checkNotNull(segment, "segment"); + this.interval = Preconditions.checkNotNull(interval, "interval"); + this.workerNumber = workerNumber; + } + + public DataSegment getSegment() + { + return segment; + } + + public Interval getInterval() + { + return interval; + } + + public int getWorkerNumber() + { + return workerNumber; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DartQueryableSegment that = (DartQueryableSegment) o; + return workerNumber == that.workerNumber + && Objects.equals(segment, that.segment) + && Objects.equals(interval, that.interval); + } + + @Override + public int hashCode() + { + return Objects.hash(segment, interval, workerNumber); + } + + @Override + public String toString() + { + return "QueryableDataSegment{" + + "segment=" + segment + + ", interval=" + interval + + ", workerNumber=" + workerNumber + + '}'; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerClient.java new file mode 100644 index 000000000000..932300de217f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerClient.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import it.unimi.dsi.fastutil.Pair; +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.java.util.http.client.response.HttpResponseHandler; +import org.apache.druid.msq.dart.controller.DartWorkerManager; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.rpc.BaseWorkerClientImpl; +import org.apache.druid.rpc.FixedServiceLocator; +import org.apache.druid.rpc.IgnoreHttpResponseHandler; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClient; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.rpc.ServiceLocation; +import org.apache.druid.rpc.ServiceRetryPolicy; +import org.apache.druid.utils.CloseableUtils; +import org.jboss.netty.handler.codec.http.HttpMethod; + +import javax.annotation.Nullable; +import java.io.Closeable; +import java.net.URI; +import java.util.HashMap; +import java.util.Map; + +/** + * Dart 
implementation of {@link WorkerClient}. Uses the same {@link BaseWorkerClientImpl} as the task-based engine. + * Each instance of this class is scoped to a single query. + */ +public class DartWorkerClient extends BaseWorkerClientImpl +{ + private static final Logger log = new Logger(DartWorkerClient.class); + + private final String queryId; + private final ServiceClientFactory clientFactory; + private final ServiceRetryPolicy retryPolicy; + + @Nullable + private final String controllerHost; + + @GuardedBy("clientMap") + private final Map> clientMap = new HashMap<>(); + + /** + * Create a worker client. + * + * @param queryId dart query ID. see {@link DartSqlEngine#CTX_DART_QUERY_ID} + * @param clientFactory service client factor + * @param smileMapper Smile object mapper + * @param controllerHost Controller host (see {@link DartWorkerResource#HEADER_CONTROLLER_HOST}) if this is a + * controller-to-worker client. Null if this is a worker-to-worker client. + */ + public DartWorkerClient( + final String queryId, + final ServiceClientFactory clientFactory, + final ObjectMapper smileMapper, + @Nullable final String controllerHost + ) + { + super(smileMapper, SmileMediaTypes.APPLICATION_JACKSON_SMILE); + this.queryId = queryId; + this.clientFactory = clientFactory; + this.controllerHost = controllerHost; + + if (controllerHost == null) { + // worker -> worker client. Retry HTTP 503 in case worker A starts up before worker B, and needs to + // contact it immediately. + this.retryPolicy = new DartWorkerRetryPolicy(true); + } else { + // controller -> worker client. Do not retry any HTTP error codes. If we retry HTTP 503 for controller -> worker, + // we can get stuck trying to contact workers that have exited. 
+ this.retryPolicy = new DartWorkerRetryPolicy(false); + } + } + + @Override + protected ServiceClient getClient(final String workerIdString) + { + final WorkerId workerId = WorkerId.fromString(workerIdString); + if (!queryId.equals(workerId.getQueryId())) { + throw DruidException.defensive("Unexpected queryId[%s]. Expected queryId[%s]", workerId.getQueryId(), queryId); + } + + synchronized (clientMap) { + return clientMap.computeIfAbsent(workerId.getHostAndPort(), ignored -> makeNewClient(workerId)).left(); + } + } + + /** + * Close a single worker's clients. Used when that worker fails, so we stop trying to contact it. + * + * @param workerHost worker host:port + */ + public void closeClient(final String workerHost) + { + synchronized (clientMap) { + final Pair clientPair = clientMap.remove(workerHost); + if (clientPair != null) { + CloseableUtils.closeAndWrapExceptions(clientPair.right()); + } + } + } + + /** + * Close all outstanding clients. + */ + @Override + public void close() + { + synchronized (clientMap) { + for (Map.Entry> entry : clientMap.entrySet()) { + CloseableUtils.closeAndSuppressExceptions( + entry.getValue().right(), + e -> log.warn(e, "Failed to close client[%s]", entry.getKey()) + ); + } + + clientMap.clear(); + } + } + + /** + * Stops a worker. Dart-only API, used by the {@link DartWorkerManager}. + */ + public ListenableFuture stopWorker(String workerId) + { + return getClient(workerId).asyncRequest( + new RequestBuilder(HttpMethod.POST, "/stop"), + IgnoreHttpResponseHandler.INSTANCE + ); + } + + /** + * Create a new client. Called by {@link #getClient(String)} if a new one is needed. 
+ */ + private Pair makeNewClient(final WorkerId workerId) + { + final URI uri = workerId.toUri(); + final FixedServiceLocator locator = new FixedServiceLocator(ServiceLocation.fromUri(uri)); + final ServiceClient baseClient = + clientFactory.makeClient(workerId.toString(), locator, retryPolicy); + final ServiceClient client; + + if (controllerHost != null) { + client = new ControllerDecoratedClient(baseClient, controllerHost); + } else { + client = baseClient; + } + + return Pair.of(client, locator); + } + + /** + * Service client that adds the {@link DartWorkerResource#HEADER_CONTROLLER_HOST} header. + */ + private static class ControllerDecoratedClient implements ServiceClient + { + private final ServiceClient delegate; + private final String controllerHost; + + ControllerDecoratedClient(final ServiceClient delegate, final String controllerHost) + { + this.delegate = delegate; + this.controllerHost = controllerHost; + } + + @Override + public ListenableFuture asyncRequest( + final RequestBuilder requestBuilder, + final HttpResponseHandler handler + ) + { + return delegate.asyncRequest( + requestBuilder.header(DartWorkerResource.HEADER_CONTROLLER_HOST, controllerHost), + handler + ); + } + + @Override + public ServiceClient withRetryPolicy(final ServiceRetryPolicy retryPolicy) + { + return new ControllerDecoratedClient(delegate.withRetryPolicy(retryPolicy), controllerHost); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java new file mode 100644 index 000000000000..525162fd8ddd --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerContext.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Preconditions; +import com.google.inject.Injector; +import org.apache.druid.collections.ResourceHolder; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.io.Closer; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.exec.ControllerClient; +import org.apache.druid.msq.exec.DataServerQueryHandlerFactory; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.msq.exec.ProcessingBuffersSet; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.exec.WorkerClient; +import org.apache.druid.msq.exec.WorkerContext; +import org.apache.druid.msq.exec.WorkerMemoryParameters; +import org.apache.druid.msq.exec.WorkerStorageParameters; +import org.apache.druid.msq.kernel.FrameContext; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.msq.util.MultiStageQueryContext; +import org.apache.druid.query.DruidProcessingConfig; +import 
org.apache.druid.query.QueryContext; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.segment.SegmentWrangler; +import org.apache.druid.server.DruidNode; +import org.checkerframework.checker.nullness.qual.MonotonicNonNull; + +import java.io.File; + +/** + * Dart implementation of {@link WorkerContext}. + * Each instance is scoped to a query. + */ +public class DartWorkerContext implements WorkerContext +{ + private final String queryId; + private final String controllerHost; + private final String workerId; + private final DruidNode selfNode; + private final ObjectMapper jsonMapper; + private final Injector injector; + private final DartWorkerClient workerClient; + private final DruidProcessingConfig processingConfig; + private final SegmentWrangler segmentWrangler; + private final GroupingEngine groupingEngine; + private final DataSegmentProvider dataSegmentProvider; + private final MemoryIntrospector memoryIntrospector; + private final ProcessingBuffersProvider processingBuffersProvider; + private final Outbox outbox; + private final File tempDir; + private final QueryContext queryContext; + + /** + * Lazy initialized upon call to {@link #frameContext(WorkOrder)}. 
+ */ + @MonotonicNonNull + private volatile ResourceHolder processingBuffersSet; + + DartWorkerContext( + final String queryId, + final String controllerHost, + final String workerId, + final DruidNode selfNode, + final ObjectMapper jsonMapper, + final Injector injector, + final DartWorkerClient workerClient, + final DruidProcessingConfig processingConfig, + final SegmentWrangler segmentWrangler, + final GroupingEngine groupingEngine, + final DataSegmentProvider dataSegmentProvider, + final MemoryIntrospector memoryIntrospector, + final ProcessingBuffersProvider processingBuffersProvider, + final Outbox outbox, + final File tempDir, + final QueryContext queryContext + ) + { + this.queryId = queryId; + this.controllerHost = controllerHost; + this.workerId = workerId; + this.selfNode = selfNode; + this.jsonMapper = jsonMapper; + this.injector = injector; + this.workerClient = workerClient; + this.processingConfig = processingConfig; + this.segmentWrangler = segmentWrangler; + this.groupingEngine = groupingEngine; + this.dataSegmentProvider = dataSegmentProvider; + this.memoryIntrospector = memoryIntrospector; + this.processingBuffersProvider = processingBuffersProvider; + this.outbox = outbox; + this.tempDir = tempDir; + this.queryContext = Preconditions.checkNotNull(queryContext, "queryContext"); + } + + @Override + public String queryId() + { + return queryId; + } + + @Override + public String workerId() + { + return workerId; + } + + @Override + public ObjectMapper jsonMapper() + { + return jsonMapper; + } + + @Override + public Injector injector() + { + return injector; + } + + @Override + public void registerWorker(Worker worker, Closer closer) + { + closer.register(() -> { + synchronized (this) { + if (processingBuffersSet != null) { + processingBuffersSet.close(); + processingBuffersSet = null; + } + } + + workerClient.close(); + }); + } + + @Override + public int maxConcurrentStages() + { + final int retVal = 
MultiStageQueryContext.getMaxConcurrentStagesWithDefault(queryContext, -1); + if (retVal <= 0) { + throw new IAE("Illegal maxConcurrentStages[%s]", retVal); + } + return retVal; + } + + @Override + public ControllerClient makeControllerClient() + { + return new DartControllerClient(outbox, queryId, controllerHost); + } + + @Override + public WorkerClient makeWorkerClient() + { + return workerClient; + } + + @Override + public File tempDir() + { + return tempDir; + } + + @Override + public FrameContext frameContext(WorkOrder workOrder) + { + if (processingBuffersSet == null) { + synchronized (this) { + if (processingBuffersSet == null) { + processingBuffersSet = processingBuffersProvider.acquire( + workOrder.getQueryDefinition(), + maxConcurrentStages() + ); + } + } + } + + final WorkerMemoryParameters memoryParameters = + WorkerMemoryParameters.createProductionInstance( + workOrder, + memoryIntrospector, + maxConcurrentStages() + ); + + final WorkerStorageParameters storageParameters = WorkerStorageParameters.createInstance(-1, false); + + return new DartFrameContext( + workOrder.getStageDefinition().getId(), + this, + segmentWrangler, + groupingEngine, + dataSegmentProvider, + processingBuffersSet.get().acquireForStage(workOrder.getStageDefinition()), + memoryParameters, + storageParameters + ); + } + + @Override + public int threadCount() + { + return processingConfig.getNumThreads(); + } + + @Override + public DataServerQueryHandlerFactory dataServerQueryHandlerFactory() + { + // We don't query data servers. Return null so this factory is ignored when the main worker code tries + // to close it. + return null; + } + + @Override + public boolean includeAllCounters() + { + // The context parameter "includeAllCounters" is meant to assist with backwards compatibility for versions prior + // to Druid 31. Dart didn't exist prior to Druid 31, so there is no need for it here. Always emit all counters. 
+ return true; + } + + @Override + public DruidNode selfNode() + { + return selfNode; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactory.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactory.java new file mode 100644 index 000000000000..429579b2195e --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactory.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.query.QueryContext; + +import java.io.File; + +/** + * Used by {@link DartWorkerRunner} to create new {@link Worker} instances. 
+ */ +public interface DartWorkerFactory +{ + Worker build(String queryId, String controllerHost, File tempDir, QueryContext context); +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java new file mode 100644 index 000000000000..eb2b25252f6a --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerFactoryImpl.java @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Inject; +import com.google.inject.Injector; +import org.apache.druid.guice.annotations.EscalatedGlobal; +import org.apache.druid.guice.annotations.Json; +import org.apache.druid.guice.annotations.Self; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.exec.MemoryIntrospector; +import org.apache.druid.msq.exec.ProcessingBuffersProvider; +import org.apache.druid.msq.exec.Worker; +import org.apache.druid.msq.exec.WorkerContext; +import org.apache.druid.msq.exec.WorkerImpl; +import org.apache.druid.msq.querykit.DataSegmentProvider; +import org.apache.druid.query.DruidProcessingConfig; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.groupby.GroupingEngine; +import org.apache.druid.rpc.ServiceClientFactory; +import org.apache.druid.segment.SegmentWrangler; +import org.apache.druid.server.DruidNode; + +import java.io.File; +import java.net.URI; +import java.net.URISyntaxException; + +/** + * Production implementation of {@link DartWorkerFactory}. 
+ */ +public class DartWorkerFactoryImpl implements DartWorkerFactory +{ + private final String id; + private final DruidNode selfNode; + private final ObjectMapper jsonMapper; + private final ObjectMapper smileMapper; + private final Injector injector; + private final ServiceClientFactory serviceClientFactory; + private final DruidProcessingConfig processingConfig; + private final SegmentWrangler segmentWrangler; + private final GroupingEngine groupingEngine; + private final DataSegmentProvider dataSegmentProvider; + private final MemoryIntrospector memoryIntrospector; + private final ProcessingBuffersProvider processingBuffersProvider; + private final Outbox outbox; + + @Inject + public DartWorkerFactoryImpl( + @Self DruidNode selfNode, + @Json ObjectMapper jsonMapper, + @Smile ObjectMapper smileMapper, + Injector injector, + @EscalatedGlobal ServiceClientFactory serviceClientFactory, + DruidProcessingConfig processingConfig, + SegmentWrangler segmentWrangler, + GroupingEngine groupingEngine, + @Dart DataSegmentProvider dataSegmentProvider, + MemoryIntrospector memoryIntrospector, + @Dart ProcessingBuffersProvider processingBuffersProvider, + Outbox outbox + ) + { + this.id = makeWorkerId(selfNode); + this.selfNode = selfNode; + this.jsonMapper = jsonMapper; + this.smileMapper = smileMapper; + this.injector = injector; + this.serviceClientFactory = serviceClientFactory; + this.processingConfig = processingConfig; + this.segmentWrangler = segmentWrangler; + this.groupingEngine = groupingEngine; + this.dataSegmentProvider = dataSegmentProvider; + this.memoryIntrospector = memoryIntrospector; + this.processingBuffersProvider = processingBuffersProvider; + this.outbox = outbox; + } + + @Override + public Worker build(String queryId, String controllerHost, File tempDir, QueryContext queryContext) + { + final WorkerContext workerContext = new DartWorkerContext( + queryId, + controllerHost, + id, + selfNode, + jsonMapper, + injector, + new DartWorkerClient(queryId, 
serviceClientFactory, smileMapper, null), + processingConfig, + segmentWrangler, + groupingEngine, + dataSegmentProvider, + memoryIntrospector, + processingBuffersProvider, + outbox, + tempDir, + queryContext + ); + + return new WorkerImpl(null, workerContext); + } + + private static String makeWorkerId(final DruidNode selfNode) + { + try { + return new URI( + selfNode.getServiceScheme(), + null, + selfNode.getHost(), + selfNode.getPortToUse(), + DartWorkerResource.PATH, + null, + null + ).toString(); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRetryPolicy.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRetryPolicy.java new file mode 100644 index 000000000000..5dbfe98ef0c5 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRetryPolicy.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker; + +import org.apache.druid.rpc.ServiceRetryPolicy; +import org.apache.druid.rpc.StandardRetryPolicy; +import org.jboss.netty.handler.codec.http.HttpResponse; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; + +/** + * Retry policy for {@link DartWorkerClient}. This is a {@link StandardRetryPolicy#unlimited()} with + * {@link #retryHttpResponse(HttpResponse)} customized to retry fewer HTTP error codes. + */ +public class DartWorkerRetryPolicy implements ServiceRetryPolicy +{ + private final boolean retryOnWorkerUnavailable; + + /** + * Create a retry policy. + * + * @param retryOnWorkerUnavailable whether this policy should retry on {@link HttpResponseStatus#SERVICE_UNAVAILABLE} + */ + public DartWorkerRetryPolicy(boolean retryOnWorkerUnavailable) + { + this.retryOnWorkerUnavailable = retryOnWorkerUnavailable; + } + + @Override + public long maxAttempts() + { + return StandardRetryPolicy.unlimited().maxAttempts(); + } + + @Override + public long minWaitMillis() + { + return StandardRetryPolicy.unlimited().minWaitMillis(); + } + + @Override + public long maxWaitMillis() + { + return StandardRetryPolicy.unlimited().maxWaitMillis(); + } + + @Override + public boolean retryHttpResponse(HttpResponse response) + { + if (retryOnWorkerUnavailable) { + return HttpResponseStatus.SERVICE_UNAVAILABLE.equals(response.getStatus()); + } else { + return false; + } + } + + @Override + public boolean retryThrowable(Throwable t) + { + return StandardRetryPolicy.unlimited().retryThrowable(t); + } + + @Override + public boolean retryLoggable() + { + return false; + } + + @Override + public boolean retryNotAvailable() + { + return false; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRunner.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRunner.java new file mode 100644 index 000000000000..ae136196a0fc --- 
/dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/DartWorkerRunner.java @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.FileUtils; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.lifecycle.LifecycleStart; +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.dart.worker.http.DartWorkerInfo; +import org.apache.druid.msq.dart.worker.http.GetWorkersResponse; +import org.apache.druid.msq.exec.Worker; +import 
org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.MSQException; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; +import org.apache.druid.query.QueryContext; +import org.apache.druid.server.security.AuthorizerMapper; +import org.joda.time.DateTime; + +import javax.annotation.Nullable; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +@ManageLifecycle +public class DartWorkerRunner +{ + private static final Logger log = new Logger(DartWorkerRunner.class); + + /** + * Set of active controllers. Ignore requests from others. + */ + @GuardedBy("this") + private final Set activeControllerHosts = new HashSet<>(); + + /** + * Query ID -> Worker instance. 
+ */ + @GuardedBy("this") + private final Map workerMap = new HashMap<>(); + private final DartWorkerFactory workerFactory; + private final ExecutorService workerExec; + private final DruidNodeDiscoveryProvider discoveryProvider; + private final ResourcePermissionMapper permissionMapper; + private final AuthorizerMapper authorizerMapper; + private final File baseTempDir; + + public DartWorkerRunner( + final DartWorkerFactory workerFactory, + final ExecutorService workerExec, + final DruidNodeDiscoveryProvider discoveryProvider, + final ResourcePermissionMapper permissionMapper, + final AuthorizerMapper authorizerMapper, + final File baseTempDir + ) + { + this.workerFactory = workerFactory; + this.workerExec = workerExec; + this.discoveryProvider = discoveryProvider; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + this.baseTempDir = baseTempDir; + } + + /** + * Start a worker, creating a holder for it. If a worker with this query ID is already started, does nothing. + * Returns the worker. 
+ * + * @throws DruidException if the controllerId does not correspond to a currently-active controller + */ + public Worker startWorker( + final String queryId, + final String controllerHost, + final QueryContext context + ) + { + final WorkerHolder holder; + final boolean newHolder; + + synchronized (this) { + if (!activeControllerHosts.contains(controllerHost)) { + throw DruidException.forPersona(DruidException.Persona.OPERATOR) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .build("Received startWorker request for unknown controller[%s]", controllerHost); + } + + final WorkerHolder existingHolder = workerMap.get(queryId); + if (existingHolder != null) { + holder = existingHolder; + newHolder = false; + } else { + final Worker worker = workerFactory.build(queryId, controllerHost, baseTempDir, context); + final WorkerResource resource = new WorkerResource(worker, permissionMapper, authorizerMapper); + holder = new WorkerHolder(worker, controllerHost, resource, DateTimes.nowUtc()); + workerMap.put(queryId, holder); + this.notifyAll(); + newHolder = true; + } + } + + if (newHolder) { + workerExec.submit(() -> { + final String originalThreadName = Thread.currentThread().getName(); + try { + Thread.currentThread().setName(StringUtils.format("%s[%s]", originalThreadName, queryId)); + holder.worker.run(); + } + catch (Throwable t) { + if (Thread.interrupted() + || t instanceof MSQException && ((MSQException) t).getFault().getErrorCode().equals(CanceledFault.CODE)) { + log.debug(t, "Canceled, exiting thread."); + } else { + log.warn(t, "Worker for query[%s] failed and stopped.", queryId); + } + } + finally { + synchronized (this) { + workerMap.remove(queryId, holder); + this.notifyAll(); + } + + Thread.currentThread().setName(originalThreadName); + } + }); + } + + return holder.worker; + } + + /** + * Stops a worker. 
+ */ + public void stopWorker(final String queryId) + { + final WorkerHolder holder; + + synchronized (this) { + holder = workerMap.get(queryId); + } + + if (holder != null) { + holder.worker.stop(); + } + } + + /** + * Get the worker resource handler for a query ID if it exists. Returns null if the worker is not running. + */ + @Nullable + public WorkerResource getWorkerResource(final String queryId) + { + synchronized (this) { + final WorkerHolder holder = workerMap.get(queryId); + if (holder != null) { + return holder.resource; + } else { + return null; + } + } + } + + /** + * Returns a {@link GetWorkersResponse} with information about all active workers. + */ + public GetWorkersResponse getWorkersResponse() + { + final List infos = new ArrayList<>(); + + synchronized (this) { + for (final Map.Entry entry : workerMap.entrySet()) { + final String queryId = entry.getKey(); + final WorkerHolder workerHolder = entry.getValue(); + infos.add( + new DartWorkerInfo( + queryId, + WorkerId.fromString(workerHolder.worker.id()), + workerHolder.controllerHost, + workerHolder.acceptTime + ) + ); + } + } + + return new GetWorkersResponse(infos); + } + + @LifecycleStart + public void start() + { + createAndCleanTempDirectory(); + + final DruidNodeDiscovery brokers = discoveryProvider.getForNodeRole(NodeRole.BROKER); + brokers.registerListener(new BrokerListener()); + } + + @LifecycleStop + public void stop() + { + synchronized (this) { + final Collection holders = workerMap.values(); + + for (final WorkerHolder holder : holders) { + holder.worker.stop(); + } + + for (final WorkerHolder holder : holders) { + holder.worker.awaitStop(); + } + } + } + + /** + * Method for testing. Waits for the set of queries to match a given predicate. 
+ */ + @VisibleForTesting + void awaitQuerySet(Predicate> queryIdsPredicate) throws InterruptedException + { + synchronized (this) { + while (!queryIdsPredicate.test(workerMap.keySet())) { + wait(); + } + } + } + + /** + * Creates the {@link #baseTempDir}, and removes any items in it that still exist. + */ + void createAndCleanTempDirectory() + { + try { + FileUtils.mkdirp(baseTempDir); + } + catch (IOException e) { + throw new RuntimeException(e); + } + + final File[] files = baseTempDir.listFiles(); + + if (files != null) { + for (final File file : files) { + if (file.isDirectory()) { + try { + FileUtils.deleteDirectory(file); + log.info("Removed stale query directory[%s].", file); + } + catch (Exception e) { + log.noStackTrace().warn(e, "Could not remove stale query directory[%s], skipping.", file); + } + } + } + } + } + + private static class WorkerHolder + { + private final Worker worker; + private final WorkerResource resource; + private final String controllerHost; + private final DateTime acceptTime; + + public WorkerHolder( + Worker worker, + String controllerHost, + WorkerResource resource, + final DateTime acceptTime + ) + { + this.worker = worker; + this.resource = resource; + this.controllerHost = controllerHost; + this.acceptTime = acceptTime; + } + } + + /** + * Listener that cancels work associated with Brokers that have gone away. 
+ */ + private class BrokerListener implements DruidNodeDiscovery.Listener + { + @Override + public void nodesAdded(Collection nodes) + { + synchronized (DartWorkerRunner.this) { + for (final DiscoveryDruidNode node : nodes) { + activeControllerHosts.add(node.getDruidNode().getHostAndPortToUse()); + } + } + } + + @Override + public void nodesRemoved(Collection nodes) + { + final Set hostsRemoved = + nodes.stream().map(node -> node.getDruidNode().getHostAndPortToUse()).collect(Collectors.toSet()); + + final List workersToNotify = new ArrayList<>(); + + synchronized (DartWorkerRunner.this) { + activeControllerHosts.removeAll(hostsRemoved); + + for (Map.Entry entry : workerMap.entrySet()) { + if (hostsRemoved.contains(entry.getValue().controllerHost)) { + workersToNotify.add(entry.getValue().worker); + } + } + } + + for (final Worker worker : workersToNotify) { + worker.controllerFailed(); + } + } + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/WorkerId.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/WorkerId.java new file mode 100644 index 000000000000..2bbff7111ca7 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/WorkerId.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonValue; +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.msq.dart.worker.http.DartWorkerResource; +import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.coordination.DruidServerMetadata; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Worker IDs, of the type returned by {@link ControllerQueryKernelConfig#getWorkerIds()}. + * + * Dart workerIds are strings of the form "scheme:host:port:queryId", like + * "https:host1.example.com:8083:2f05528c-a882-4da5-8b7d-2ecafb7f3f4e". 
+ */ +public class WorkerId +{ + private static final Pattern PATTERN = Pattern.compile("^(\\w+):(.+:\\d+):([a-z0-9-]+)$"); + + private final String scheme; + private final String hostAndPort; + private final String queryId; + private final String fullString; + + public WorkerId(final String scheme, final String hostAndPort, final String queryId) + { + this.scheme = Preconditions.checkNotNull(scheme, "scheme"); + this.hostAndPort = Preconditions.checkNotNull(hostAndPort, "hostAndPort"); + this.queryId = Preconditions.checkNotNull(queryId, "queryId"); + this.fullString = Joiner.on(':').join(scheme, hostAndPort, queryId); + } + + @JsonCreator + public static WorkerId fromString(final String s) + { + if (s == null) { + throw new IAE("Missing workerId"); + } + + final Matcher matcher = PATTERN.matcher(s); + if (matcher.matches()) { + return new WorkerId(matcher.group(1), matcher.group(2), matcher.group(3)); + } else { + throw new IAE("Invalid workerId[%s]", s); + } + } + + /** + * Create a worker ID, which is a URL. + */ + public static WorkerId fromDruidNode(final DruidNode node, final String queryId) + { + return new WorkerId( + node.getServiceScheme(), + node.getHostAndPortToUse(), + queryId + ); + } + + /** + * Create a worker ID, which is a URL. + */ + public static WorkerId fromDruidServerMetadata(final DruidServerMetadata server, final String queryId) + { + return new WorkerId( + server.getHostAndTlsPort() != null ? 
"https" : "http", + server.getHost(), + queryId + ); + } + + public String getScheme() + { + return scheme; + } + + public String getHostAndPort() + { + return hostAndPort; + } + + public String getQueryId() + { + return queryId; + } + + public URI toUri() + { + try { + final String path = StringUtils.format( + "%s/workers/%s", + DartWorkerResource.PATH, + StringUtils.urlEncode(queryId) + ); + + return new URI(scheme, hostAndPort, path, null, null); + } + catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + @JsonValue + public String toString() + { + return fullString; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + WorkerId workerId = (WorkerId) o; + return Objects.equals(fullString, workerId.fullString); + } + + @Override + public int hashCode() + { + return fullString.hashCode(); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfo.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfo.java new file mode 100644 index 000000000000..3bd14993ded8 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfo.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.worker.http; + +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.msq.dart.controller.http.DartQueryInfo; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.joda.time.DateTime; + +import java.util.Objects; + +/** + * Class included in {@link GetWorkersResponse}. + */ +public class DartWorkerInfo +{ + private final String dartQueryId; + private final WorkerId workerId; + private final String controllerHost; + private final DateTime startTime; + + public DartWorkerInfo( + @JsonProperty("dartQueryId") final String dartQueryId, + @JsonProperty("workerId") final WorkerId workerId, + @JsonProperty("controllerHost") final String controllerHost, + @JsonProperty("startTime") final DateTime startTime + ) + { + this.dartQueryId = dartQueryId; + this.workerId = workerId; + this.controllerHost = controllerHost; + this.startTime = startTime; + } + + /** + * Dart query ID generated by the system. Globally unique. + */ + @JsonProperty + public String getDartQueryId() + { + return dartQueryId; + } + + /** + * Worker ID for this query. + */ + @JsonProperty + public WorkerId getWorkerId() + { + return workerId; + } + + /** + * Controller host:port that manages this query. + */ + @JsonProperty + public String getControllerHost() + { + return controllerHost; + } + + /** + * Time this query was accepted by this worker. May be somewhat later than the {@link DartQueryInfo#getStartTime()} + * on the controller. 
+ */ + @JsonProperty + public DateTime getStartTime() + { + return startTime; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DartWorkerInfo that = (DartWorkerInfo) o; + return Objects.equals(dartQueryId, that.dartQueryId) + && Objects.equals(workerId, that.workerId) + && Objects.equals(controllerHost, that.controllerHost) + && Objects.equals(startTime, that.startTime); + } + + @Override + public int hashCode() + { + return Objects.hash(dartQueryId, workerId, controllerHost, startTime); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerResource.java new file mode 100644 index 000000000000..03fd847cb1af --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/DartWorkerResource.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker.http; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.inject.Inject; +import org.apache.druid.error.DruidException; +import org.apache.druid.guice.LazySingleton; +import org.apache.druid.guice.annotations.Smile; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.messages.server.MessageRelayResource; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.msq.dart.Dart; +import org.apache.druid.msq.dart.controller.messages.ControllerMessage; +import org.apache.druid.msq.dart.worker.DartWorkerRunner; +import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.rpc.MSQResourceUtils; +import org.apache.druid.msq.rpc.ResourcePermissionMapper; +import org.apache.druid.msq.rpc.WorkerResource; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.initialization.jetty.ServiceUnavailableException; +import org.apache.druid.server.security.AuthorizerMapper; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; + +/** + * Subclass of {@link WorkerResource} suitable for usage on a Historical. + * + * Note that this is not the same resource as used by {@link org.apache.druid.msq.indexing.MSQWorkerTask}. + * For that, see {@link org.apache.druid.msq.indexing.client.WorkerChatHandler}. + */ +@LazySingleton +@Path(DartWorkerResource.PATH + '/') +public class DartWorkerResource +{ + /** + * Root of worker APIs. + */ + public static final String PATH = "/druid/dart-worker"; + + /** + * Header containing the controller host:port, from {@link DruidNode#getHostAndPortToUse()}. 
+ */ + public static final String HEADER_CONTROLLER_HOST = "X-Dart-Controller-Host"; + + private final DartWorkerRunner workerRunner; + private final ResourcePermissionMapper permissionMapper; + private final AuthorizerMapper authorizerMapper; + private final MessageRelayResource messageRelayResource; + + @Inject + public DartWorkerResource( + final DartWorkerRunner workerRunner, + @Dart final ResourcePermissionMapper permissionMapper, + @Smile final ObjectMapper smileMapper, + final Outbox outbox, + final AuthorizerMapper authorizerMapper + ) + { + this.workerRunner = workerRunner; + this.permissionMapper = permissionMapper; + this.authorizerMapper = authorizerMapper; + this.messageRelayResource = new MessageRelayResource<>( + outbox, + smileMapper, + ControllerMessage.class + ); + } + + /** + * API for retrieving all currently-running queries. + */ + @GET + @Produces(MediaType.APPLICATION_JSON) + @Path("/workers") + public GetWorkersResponse httpGetWorkers(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + return workerRunner.getWorkersResponse(); + } + + /** + * Like {@link WorkerResource#httpPostWorkOrder(WorkOrder, HttpServletRequest)}, but implicitly starts a worker + * when the work order is posted. Shadows {@link WorkerResource#httpPostWorkOrder(WorkOrder, HttpServletRequest)}. 
+ */ + @POST + @Consumes({MediaType.APPLICATION_JSON, SmileMediaTypes.APPLICATION_JACKSON_SMILE}) + @Path("/workers/{queryId}/workOrder") + public Response httpPostWorkOrder( + final WorkOrder workOrder, + @PathParam("queryId") final String queryId, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + final String controllerHost = req.getHeader(HEADER_CONTROLLER_HOST); + if (controllerHost == null) { + throw DruidException.forPersona(DruidException.Persona.DEVELOPER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .build("Missing controllerId[%s]", HEADER_CONTROLLER_HOST); + } + + workerRunner.startWorker(queryId, controllerHost, workOrder.getWorkerContext()) + .postWorkOrder(workOrder); + + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * Stops a worker. Returns immediately; does not wait for the worker to actually finish. + */ + @POST + @Path("/workers/{queryId}/stop") + public Response httpPostStopWorker( + @PathParam("queryId") final String queryId, + @Context final HttpServletRequest req + ) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + workerRunner.stopWorker(queryId); + return Response.status(Response.Status.ACCEPTED).build(); + } + + /** + * Handles all {@link WorkerResource} calls, except {@link WorkerResource#httpPostWorkOrder}, which is handled + * by {@link #httpPostWorkOrder(WorkOrder, String, HttpServletRequest)}. + */ + @Path("/workers/{queryId}") + public Object httpCallWorkerResource( + @PathParam("queryId") final String queryId, + @Context final HttpServletRequest req + ) + { + final WorkerResource resource = workerRunner.getWorkerResource(queryId); + + if (resource != null) { + return resource; + } else { + // Return HTTP 503 (Service Unavailable) so worker -> worker clients can retry. 
When workers are first starting + // up and contacting each other, worker A may contact worker B before worker B has started up. In the future, it + // would be better to do an async wait, with some timeout, for the worker to show up before returning 503. + // That way a retry wouldn't be necessary. + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + throw new ServiceUnavailableException(StringUtils.format("No worker running for query[%s]", queryId)); + } + } + + @Path("/relay") + public Object httpCallMessageRelayServer(@Context final HttpServletRequest req) + { + MSQResourceUtils.authorizeAdminRequest(permissionMapper, authorizerMapper, req); + return messageRelayResource; + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponse.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponse.java new file mode 100644 index 000000000000..0fa28a4ef17f --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponse.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.worker.http; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.servlet.http.HttpServletRequest; +import java.util.List; +import java.util.Objects; + +/** + * Response from {@link DartWorkerResource#httpGetWorkers(HttpServletRequest)}, the "get all workers" API. + */ +public class GetWorkersResponse +{ + private final List workers; + + public GetWorkersResponse(@JsonProperty("workers") final List workers) + { + this.workers = workers; + } + + @JsonProperty + public List getWorkers() + { + return workers; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GetWorkersResponse that = (GetWorkersResponse) o; + return Objects.equals(workers, that.workers); + } + + @Override + public int hashCode() + { + return Objects.hashCode(workers); + } +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java index d2370b057935..f4a2448595fb 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/Controller.java @@ -22,6 +22,8 @@ import org.apache.druid.indexer.report.TaskReport; import org.apache.druid.msq.counters.CounterSnapshots; import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.dart.controller.http.DartSqlResource; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.client.ControllerChatHandler; import org.apache.druid.msq.indexing.error.MSQErrorReport; @@ -42,6 +44,7 @@ public interface Controller * Unique task/query ID for the batch query run by this controller. 
* * Controller IDs must be globally unique. For tasks, this is the task ID from {@link MSQControllerTask#getId()}. + * For Dart, this is {@link DartSqlEngine#CTX_DART_QUERY_ID}, set by {@link DartSqlResource}. */ String queryId(); @@ -121,6 +124,11 @@ void resultsComplete( */ List getWorkerIds(); + /** + * Returns whether this controller has a worker with the given ID. + */ + boolean hasWorker(String workerId); + @Nullable TaskReport.ReportMap liveReports(); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 2d0a6212a0ad..60e0910e15b6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -1168,6 +1168,16 @@ public List getWorkerIds() return workerManager.getWorkerIds(); } + @Override + public boolean hasWorker(String workerId) + { + if (workerManager == null) { + return false; + } + + return workerManager.getWorkerNumber(workerId) != WorkerManager.UNKNOWN_WORKER_NUMBER; + } + @SuppressWarnings({"unchecked", "rawtypes"}) @Nullable private Int2ObjectMap makeWorkerFactoryInfosForStage( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java index 702302f7ea1a..89b1eff11c64 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerImpl.java @@ -79,6 +79,7 @@ import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryProcessingPool; import org.apache.druid.server.DruidNode; +import org.apache.druid.utils.CloseableUtils; import javax.annotation.Nullable; import 
java.io.Closeable; @@ -988,6 +989,11 @@ private void doCancel() controllerClient.close(); } + // Close worker client to cancel any currently in-flight calls to other workers. + if (workerClient != null) { + CloseableUtils.closeAndSuppressExceptions(workerClient, e -> log.warn(e, "Failed to close workerClient")); + } + // Clear the main loop event queue, then throw a CanceledFault into the loop to exit it promptly. kernelManipulationQueue.clear(); kernelManipulationQueue.add( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java index ebce4821d591..31af0953d2f9 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerManager.java @@ -83,8 +83,11 @@ public interface WorkerManager Map> getWorkerStats(); /** - * Blocks until all workers exit. Returns quietly, no matter whether there was an exception associated with the - * future from {@link #start()} or not. + * Stop all workers. + * + * The task-based implementation blocks until all tasks exit. Dart's implementation queues workers for stopping in + * the background, and returns immediately. Either way, this method returns quietly, no matter whether there was an + * exception associated with the future from {@link #start()} or not. 
* * @param interrupt whether to interrupt currently-running work */ diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java index 4cc4678a58a7..be73a3cbfdd0 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/TaskReportQueryListener.java @@ -71,6 +71,7 @@ public class TaskReportQueryListener implements QueryListener private JsonGenerator jg; private long numResults; private MSQStatusReport statusReport; + private boolean resultsCurrentlyOpen; public TaskReportQueryListener( final MSQDestination destination, @@ -99,6 +100,7 @@ public void onResultsStart(List signature, @Null { try { openGenerator(); + resultsCurrentlyOpen = true; jg.writeObjectFieldStart(FIELD_RESULTS); writeObjectField(FIELD_RESULTS_SIGNATURE, signature); @@ -118,15 +120,7 @@ public boolean onResultRow(Object[] row) try { JacksonUtils.writeObjectUsingSerializerProvider(jg, serializers, row); numResults++; - - if (rowsInTaskReport == MSQDestination.UNLIMITED || numResults < rowsInTaskReport) { - return true; - } else { - jg.writeEndArray(); - jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, true); - jg.writeEndObject(); - return false; - } + return rowsInTaskReport == MSQDestination.UNLIMITED || numResults < rowsInTaskReport; } catch (IOException e) { throw new RuntimeException(e); @@ -137,6 +131,8 @@ public boolean onResultRow(Object[] row) public void onResultsComplete() { try { + resultsCurrentlyOpen = false; + jg.writeEndArray(); jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, false); jg.writeEndObject(); @@ -150,7 +146,14 @@ public void onResultsComplete() public void onQueryComplete(MSQTaskReportPayload report) { try { - openGenerator(); + if (resultsCurrentlyOpen) { + 
jg.writeEndArray(); + jg.writeBooleanField(FIELD_RESULTS_TRUNCATED, true); + jg.writeEndObject(); + } else { + openGenerator(); + } + statusReport = report.getStatus(); writeObjectField(FIELD_STATUS, report.getStatus()); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java index c81572a88165..2798a3ccfaa6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/CanceledFault.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonTypeName; +import org.apache.druid.error.DruidException; @JsonTypeName(CanceledFault.CODE) public class CanceledFault extends BaseMSQFault @@ -38,4 +39,13 @@ public static CanceledFault instance() { return INSTANCE; } + + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.CANCELED) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java index c2c4617292e0..0ad60bdb0b03 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnNameRestrictedFault.java @@ -23,6 +23,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import 
com.google.common.base.Preconditions; +import org.apache.druid.error.DruidException; import org.apache.druid.java.util.common.StringUtils; import java.util.Objects; @@ -51,6 +52,15 @@ public String getColumnName() return columnName; } + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java index 91764b4b3988..2337837785ee 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/ColumnTypeNotSupportedFault.java @@ -24,6 +24,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonTypeName; import com.google.common.base.Preconditions; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.write.UnsupportedColumnTypeException; import org.apache.druid.segment.column.ColumnType; @@ -65,6 +66,15 @@ public ColumnType getColumnType() return columnType; } + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.INVALID_INPUT) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java index 8d90bef32ff2..aa515c8b46dc 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQErrorReport.java @@ -25,6 +25,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import it.unimi.dsi.fastutil.ints.IntList; +import org.apache.druid.error.DruidException; import org.apache.druid.frame.processor.FrameRowTooLargeException; import org.apache.druid.frame.write.InvalidFieldException; import org.apache.druid.frame.write.InvalidNullByteException; @@ -138,6 +139,31 @@ public String getExceptionStackTrace() return exceptionStackTrace; } + /** + * Returns a {@link DruidException} "equivalent" of this instance. This is useful until such time as we can migrate + * usages of this class to {@link DruidException}. 
+ */ + public DruidException toDruidException() + { + final DruidException druidException = + error.toDruidException() + .withContext("taskId", taskId); + + if (host != null) { + druidException.withContext("host", host); + } + + if (stageNumber != null) { + druidException.withContext("stageNumber", stageNumber); + } + + if (exceptionStackTrace != null) { + druidException.withContext("exceptionStackTrace", exceptionStackTrace); + } + + return druidException; + } + @Override public boolean equals(Object o) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java index c36157e0ddca..39efce9d2044 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/MSQFault.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.indexing.error; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import org.apache.druid.error.DruidException; import javax.annotation.Nullable; @@ -36,4 +37,17 @@ public interface MSQFault @Nullable String getErrorMessage(); + /** + * Returns a {@link DruidException} corresponding to this fault. + * + * The default is a {@link DruidException.Category#RUNTIME_FAILURE} targeting {@link DruidException.Persona#USER}. + * Faults with different personas and categories should override this method. 
+ */ + default DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.RUNTIME_FAILURE) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java index bba058cd5888..7356cc029092 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/indexing/error/QueryNotSupportedFault.java @@ -21,6 +21,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonTypeName; +import org.apache.druid.error.DruidException; @JsonTypeName(QueryNotSupportedFault.CODE) public class QueryNotSupportedFault extends BaseMSQFault @@ -33,6 +34,15 @@ public class QueryNotSupportedFault extends BaseMSQFault super(CODE); } + @Override + public DruidException toDruidException() + { + return DruidException.forPersona(DruidException.Persona.USER) + .ofCategory(DruidException.Category.UNSUPPORTED) + .withErrorCode(getErrorCode()) + .build(MSQFaultUtils.generateMessageWithErrorCode(this)); + } + @JsonCreator public static QueryNotSupportedFault instance() { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java index 6ec23119a228..2ed7b1784aef 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java @@ -57,6 +57,8 @@ */ public abstract class 
BaseWorkerClientImpl implements WorkerClient { + private static final Logger log = new Logger(BaseWorkerClientImpl.class); + private final ObjectMapper objectMapper; private final String contentType; @@ -191,8 +193,6 @@ public ListenableFuture getCounters(String workerId) ); } - private static final Logger log = new Logger(BaseWorkerClientImpl.class); - @Override public ListenableFuture fetchChannelData( String workerId, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java index 839defa6bd9c..20758883ddba 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java @@ -56,6 +56,7 @@ import javax.ws.rs.core.StreamingOutput; import java.io.InputStream; import java.io.OutputStream; +import java.util.concurrent.atomic.AtomicBoolean; public class WorkerResource { @@ -104,6 +105,8 @@ public Response httpGetChannelData( worker.readStageOutput(new StageId(queryId, stageNumber), partitionNumber, offset); final AsyncContext asyncContext = req.startAsync(); + final AtomicBoolean responseResolved = new AtomicBoolean(); + asyncContext.setTimeout(GET_CHANNEL_DATA_TIMEOUT); asyncContext.addListener( new AsyncListener() @@ -116,6 +119,10 @@ public void onComplete(AsyncEvent event) @Override public void onTimeout(AsyncEvent event) { + if (!responseResolved.compareAndSet(false, true)) { + return; + } + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); response.setStatus(HttpServletResponse.SC_OK); event.getAsyncContext().complete(); @@ -144,7 +151,11 @@ public void onStartAsync(AsyncEvent event) @Override public void onSuccess(final InputStream inputStream) { - HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + if 
(!responseResolved.compareAndSet(false, true)) { + return; + } + + final HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); try (final OutputStream outputStream = response.getOutputStream()) { if (inputStream == null) { @@ -188,7 +199,7 @@ public void onSuccess(final InputStream inputStream) @Override public void onFailure(Throwable e) { - if (!dataFuture.isCancelled()) { + if (responseResolved.compareAndSet(false, true)) { try { HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java index ae667a7a5585..6cf1dc504554 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskQueryMaker.java @@ -28,6 +28,7 @@ import org.apache.druid.error.DruidException; import org.apache.druid.error.InvalidInput; import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.java.util.common.Pair; import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; @@ -56,7 +57,6 @@ import org.apache.druid.sql.calcite.parser.DruidSqlIngest; import org.apache.druid.sql.calcite.parser.DruidSqlInsert; import org.apache.druid.sql.calcite.parser.DruidSqlReplace; -import org.apache.druid.sql.calcite.planner.ColumnMapping; import org.apache.druid.sql.calcite.planner.ColumnMappings; import org.apache.druid.sql.calcite.planner.PlannerContext; import org.apache.druid.sql.calcite.planner.QueryUtils; @@ -95,7 +95,6 @@ public class MSQTaskQueryMaker implements QueryMaker private final List> fieldMapping; 
private final MSQTerminalStageSpecFactory terminalStageSpecFactory; - MSQTaskQueryMaker( @Nullable final IngestDestination targetDataSource, final OverlordClient overlordClient, @@ -119,6 +118,38 @@ public QueryResponse runQuery(final DruidQuery druidQuery) Hook.QUERY_PLAN.run(druidQuery.getQuery()); String taskId = MSQTasks.controllerTaskId(plannerContext.getSqlQueryId()); + final Map taskContext = new HashMap<>(); + taskContext.put(LookupLoadingSpec.CTX_LOOKUP_LOADING_MODE, plannerContext.getLookupLoadingSpec().getMode()); + if (plannerContext.getLookupLoadingSpec().getMode() == LookupLoadingSpec.Mode.ONLY_REQUIRED) { + taskContext.put(LookupLoadingSpec.CTX_LOOKUPS_TO_LOAD, plannerContext.getLookupLoadingSpec().getLookupsToLoad()); + } + + final List> typeList = getTypes(druidQuery, fieldMapping, plannerContext); + + final MSQControllerTask controllerTask = new MSQControllerTask( + taskId, + makeQuerySpec(targetDataSource, druidQuery, fieldMapping, plannerContext, terminalStageSpecFactory), + MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(plannerContext.getSql()), + plannerContext.queryContextMap(), + SqlResults.Context.fromPlannerContext(plannerContext), + typeList.stream().map(typeInfo -> typeInfo.lhs).collect(Collectors.toList()), + typeList.stream().map(typeInfo -> typeInfo.rhs).collect(Collectors.toList()), + taskContext + ); + + FutureUtils.getUnchecked(overlordClient.runTask(taskId, controllerTask), true); + return QueryResponse.withEmptyContext(Sequences.simple(Collections.singletonList(new Object[]{taskId}))); + } + + public static MSQSpec makeQuerySpec( + @Nullable final IngestDestination targetDataSource, + final DruidQuery druidQuery, + final List> fieldMapping, + final PlannerContext plannerContext, + final MSQTerminalStageSpecFactory terminalStageSpecFactory + ) + { + // SQL query context: context provided by the user, and potentially modified by handlers during planning. 
// Does not directly influence task execution, but it does form the basis for the initial native query context, // which *does* influence task execution. @@ -135,23 +166,18 @@ public QueryResponse runQuery(final DruidQuery druidQuery) MSQMode.populateDefaultQueryContext(msqMode, nativeQueryContext); } - Object segmentGranularity; - try { - segmentGranularity = Optional.ofNullable(plannerContext.queryContext() - .get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) - .orElse(jsonMapper.writeValueAsString(DEFAULT_SEGMENT_GRANULARITY)); - } - catch (JsonProcessingException e) { - // This would only be thrown if we are unable to serialize the DEFAULT_SEGMENT_GRANULARITY, which we don't expect - // to happen - throw DruidException.defensive() - .build( - e, - "Unable to deserialize the DEFAULT_SEGMENT_GRANULARITY in MSQTaskQueryMaker. " - + "This shouldn't have happened since the DEFAULT_SEGMENT_GRANULARITY object is guaranteed to be " - + "serializable. Please raise an issue in case you are seeing this message while executing a query." - ); - } + Object segmentGranularity = + Optional.ofNullable(plannerContext.queryContext().get(DruidSqlInsert.SQL_INSERT_SEGMENT_GRANULARITY)) + .orElseGet(() -> { + try { + return plannerContext.getJsonMapper().writeValueAsString(DEFAULT_SEGMENT_GRANULARITY); + } + catch (JsonProcessingException e) { + // This would only be thrown if we are unable to serialize the DEFAULT_SEGMENT_GRANULARITY, + // which we don't expect to happen. 
+ throw DruidException.defensive().build(e, "Unable to serialize DEFAULT_SEGMENT_GRANULARITY"); + } + }); final int maxNumTasks = MultiStageQueryContext.getMaxNumTasks(sqlQueryContext); @@ -167,7 +193,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) final int rowsPerSegment = MultiStageQueryContext.getRowsPerSegment(sqlQueryContext); final int maxRowsInMemory = MultiStageQueryContext.getRowsInMemory(sqlQueryContext); final Integer maxNumSegments = MultiStageQueryContext.getMaxNumSegments(sqlQueryContext); - final IndexSpec indexSpec = MultiStageQueryContext.getIndexSpec(sqlQueryContext, jsonMapper); + final IndexSpec indexSpec = MultiStageQueryContext.getIndexSpec(sqlQueryContext, plannerContext.getJsonMapper()); final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(sqlQueryContext); final List replaceTimeChunks = @@ -190,29 +216,6 @@ public QueryResponse runQuery(final DruidQuery druidQuery) ) .orElse(null); - // For assistance computing return types if !finalizeAggregations. - final Map aggregationIntermediateTypeMap = - finalizeAggregations ? 
null /* Not needed */ : buildAggregationIntermediateTypeMap(druidQuery); - - final List sqlTypeNames = new ArrayList<>(); - final List columnTypeList = new ArrayList<>(); - final List columnMappings = QueryUtils.buildColumnMappings(fieldMapping, druidQuery); - - for (final Entry entry : fieldMapping) { - final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey()); - - final SqlTypeName sqlTypeName; - - if (!finalizeAggregations && aggregationIntermediateTypeMap.containsKey(queryColumn)) { - final ColumnType druidType = aggregationIntermediateTypeMap.get(queryColumn); - sqlTypeName = new RowSignatures.ComplexSqlType(SqlTypeName.OTHER, druidType, true).getSqlTypeName(); - } else { - sqlTypeName = druidQuery.getOutputRowType().getFieldList().get(entry.getKey()).getType().getSqlTypeName(); - } - sqlTypeNames.add(sqlTypeName); - columnTypeList.add(druidQuery.getOutputRowSignature().getColumnType(queryColumn).orElse(ColumnType.STRING)); - } - final MSQDestination destination; if (targetDataSource instanceof ExportDestination) { @@ -226,7 +229,8 @@ public QueryResponse runQuery(final DruidQuery druidQuery) } else if (targetDataSource instanceof TableDestination) { Granularity segmentGranularityObject; try { - segmentGranularityObject = jsonMapper.readValue((String) segmentGranularity, Granularity.class); + segmentGranularityObject = + plannerContext.getJsonMapper().readValue((String) segmentGranularity, Granularity.class); } catch (Exception e) { throw DruidException.defensive() @@ -285,7 +289,7 @@ public QueryResponse runQuery(final DruidQuery druidQuery) final MSQSpec querySpec = MSQSpec.builder() .query(druidQuery.getQuery().withOverriddenContext(nativeQueryContextOverrides)) - .columnMappings(new ColumnMappings(columnMappings)) + .columnMappings(new ColumnMappings(QueryUtils.buildColumnMappings(fieldMapping, druidQuery))) .destination(destination) .assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(sqlQueryContext)) 
.tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment, maxNumSegments, indexSpec)) @@ -293,25 +297,42 @@ public QueryResponse runQuery(final DruidQuery druidQuery) MSQTaskQueryMakerUtils.validateRealtimeReindex(querySpec); - final Map context = new HashMap<>(); - context.put(LookupLoadingSpec.CTX_LOOKUP_LOADING_MODE, plannerContext.getLookupLoadingSpec().getMode()); - if (plannerContext.getLookupLoadingSpec().getMode() == LookupLoadingSpec.Mode.ONLY_REQUIRED) { - context.put(LookupLoadingSpec.CTX_LOOKUPS_TO_LOAD, plannerContext.getLookupLoadingSpec().getLookupsToLoad()); - } + return querySpec.withOverriddenContext(nativeQueryContext); + } - final MSQControllerTask controllerTask = new MSQControllerTask( - taskId, - querySpec.withOverriddenContext(nativeQueryContext), - MSQTaskQueryMakerUtils.maskSensitiveJsonKeys(plannerContext.getSql()), - plannerContext.queryContextMap(), - SqlResults.Context.fromPlannerContext(plannerContext), - sqlTypeNames, - columnTypeList, - context - ); + public static List> getTypes( + final DruidQuery druidQuery, + final List> fieldMapping, + final PlannerContext plannerContext + ) + { + final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(plannerContext.queryContext()); - FutureUtils.getUnchecked(overlordClient.runTask(taskId, controllerTask), true); - return QueryResponse.withEmptyContext(Sequences.simple(Collections.singletonList(new Object[]{taskId}))); + // For assistance computing return types if !finalizeAggregations. + final Map aggregationIntermediateTypeMap = + finalizeAggregations ? 
null /* Not needed */ : buildAggregationIntermediateTypeMap(druidQuery); + + final List> retVal = new ArrayList<>(); + + for (final Entry entry : fieldMapping) { + final String queryColumn = druidQuery.getOutputRowSignature().getColumnName(entry.getKey()); + + final SqlTypeName sqlTypeName; + + if (!finalizeAggregations && aggregationIntermediateTypeMap.containsKey(queryColumn)) { + final ColumnType druidType = aggregationIntermediateTypeMap.get(queryColumn); + sqlTypeName = new RowSignatures.ComplexSqlType(SqlTypeName.OTHER, druidType, true).getSqlTypeName(); + } else { + sqlTypeName = druidQuery.getOutputRowType().getFieldList().get(entry.getKey()).getType().getSqlTypeName(); + } + + final ColumnType columnType = + druidQuery.getOutputRowSignature().getColumnType(queryColumn).orElse(ColumnType.STRING); + + retVal.add(Pair.of(sqlTypeName, columnType)); + } + + return retVal; } private static Map buildAggregationIntermediateTypeMap(final DruidQuery druidQuery) diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java index 1964ad3de4ca..31a2f5e5e643 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/MSQTaskSqlEngine.java @@ -42,6 +42,7 @@ import org.apache.druid.java.util.common.StringUtils; import org.apache.druid.java.util.common.granularity.Granularities; import org.apache.druid.java.util.common.granularity.Granularity; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; import org.apache.druid.msq.indexing.destination.MSQTerminalStageSpecFactory; import org.apache.druid.msq.querykit.QueryKitUtils; import org.apache.druid.msq.util.ArrayIngestMode; @@ -73,6 +74,9 @@ public class MSQTaskSqlEngine implements SqlEngine { + /** + * Context parameters disallowed for all MSQ 
engines: task (this one) as well as {@link DartSqlEngine#toString()}. + */ public static final Set SYSTEM_CONTEXT_PARAMETERS = ImmutableSet.builder() .addAll(NativeSqlEngine.SYSTEM_CONTEXT_PARAMETERS) @@ -113,13 +117,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return getMSQStructType(typeFactory); } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return getMSQStructType(typeFactory); } @@ -387,7 +399,11 @@ private static void validateTypeChanges( final ColumnType oldDruidType = Calcites.getColumnTypeForRelDataType(oldSqlTypeField.getType()); final RelDataType newSqlType = rootRel.getRowType().getFieldList().get(columnIndex).getType(); final ColumnType newDruidType = - DimensionSchemaUtils.getDimensionType(columnName, Calcites.getColumnTypeForRelDataType(newSqlType), arrayIngestMode); + DimensionSchemaUtils.getDimensionType( + columnName, + Calcites.getColumnTypeForRelDataType(newSqlType), + arrayIngestMode + ); if (newDruidType.isArray() && oldDruidType.is(ValueType.STRING) || (newDruidType.is(ValueType.STRING) && oldDruidType.isArray())) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java index a30c9bb0aec0..36c90a21f002 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MSQTaskQueryMakerUtils.java @@ -28,7 
+28,6 @@ import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.destination.DataSourceMSQDestination; -import java.util.HashSet; import java.util.List; import java.util.Set; import java.util.regex.Matcher; @@ -82,10 +81,8 @@ public static void validateContextSortOrderColumnsExist( final Set allOutputColumns ) { - final Set allOutputColumnsSet = new HashSet<>(allOutputColumns); - for (final String column : contextSortOrder) { - if (!allOutputColumnsSet.contains(column)) { + if (!allOutputColumns.contains(column)) { throw InvalidSqlInput.exception( "Column[%s] from context parameter[%s] does not appear in the query output", column, diff --git a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule index 92be5604cb8a..1058d5d5f99e 100644 --- a/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule +++ b/extensions-core/multi-stage-query/src/main/resources/META-INF/services/org.apache.druid.initialization.DruidModule @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+org.apache.druid.msq.dart.guice.DartControllerMemoryManagementModule +org.apache.druid.msq.dart.guice.DartControllerModule +org.apache.druid.msq.dart.guice.DartWorkerMemoryManagementModule +org.apache.druid.msq.dart.guice.DartWorkerModule org.apache.druid.msq.guice.IndexerMemoryManagementModule org.apache.druid.msq.guice.MSQDurableStorageModule org.apache.druid.msq.guice.MSQExternalDataSourceModule diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicerTest.java new file mode 100644 index 000000000000..be67fe860abf --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartTableInputSpecSlicerTest.java @@ -0,0 +1,488 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Ordering; +import it.unimi.dsi.fastutil.ints.IntList; +import it.unimi.dsi.fastutil.ints.IntLists; +import org.apache.druid.client.DruidServer; +import org.apache.druid.client.TimelineServerView; +import org.apache.druid.client.selector.HighestPriorityTierSelectorStrategy; +import org.apache.druid.client.selector.QueryableDruidServer; +import org.apache.druid.client.selector.RandomServerSelectorStrategy; +import org.apache.druid.client.selector.ServerSelector; +import org.apache.druid.data.input.StringTuple; +import org.apache.druid.java.util.common.Intervals; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.input.InputSlice; +import org.apache.druid.msq.input.NilInputSlice; +import org.apache.druid.msq.input.table.RichSegmentDescriptor; +import org.apache.druid.msq.input.table.SegmentsInputSlice; +import org.apache.druid.msq.input.table.TableInputSpec; +import org.apache.druid.query.TableDataSource; +import org.apache.druid.query.filter.EqualityFilter; +import org.apache.druid.segment.column.ColumnType; +import org.apache.druid.server.coordination.DruidServerMetadata; +import org.apache.druid.server.coordination.ServerType; +import org.apache.druid.testing.InitializedNullHandlingTest; +import org.apache.druid.timeline.DataSegment; +import org.apache.druid.timeline.VersionedIntervalTimeline; +import org.apache.druid.timeline.partition.DimensionRangeShardSpec; +import org.apache.druid.timeline.partition.NumberedShardSpec; +import org.apache.druid.timeline.partition.TombstoneShardSpec; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import 
java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +public class DartTableInputSpecSlicerTest extends InitializedNullHandlingTest +{ + private static final String QUERY_ID = "abc"; + private static final String DATASOURCE = "test-ds"; + private static final String DATASOURCE_NONEXISTENT = "nonexistent-ds"; + private static final String PARTITION_DIM = "dim"; + private static final long BYTES_PER_SEGMENT = 1000; + + /** + * List of servers, with descending priority, so earlier servers are preferred by the {@link ServerSelector}. + * This makes tests deterministic. + */ + private static final List SERVERS = ImmutableList.of( + new DruidServerMetadata("no", "localhost:1001", null, 1, ServerType.HISTORICAL, "__default", 2), + new DruidServerMetadata("no", "localhost:1002", null, 1, ServerType.HISTORICAL, "__default", 1), + new DruidServerMetadata("no", "localhost:1003", null, 1, ServerType.REALTIME, "__default", 0) + ); + + /** + * Dart {@link WorkerId} derived from {@link #SERVERS}. + */ + private static final List WORKER_IDS = + SERVERS.stream() + .map(server -> new WorkerId("http", server.getHostAndPort(), QUERY_ID).toString()) + .collect(Collectors.toList()); + + /** + * Segment that is one of two in a range-partitioned time chunk. + */ + private static final DataSegment SEGMENT1 = new DataSegment( + DATASOURCE, + Intervals.of("2000/2001"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new DimensionRangeShardSpec(ImmutableList.of(PARTITION_DIM), null, new StringTuple(new String[]{"foo"}), 0, 2), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that is one of two in a range-partitioned time chunk. 
+ */ + private static final DataSegment SEGMENT2 = new DataSegment( + DATASOURCE, + Intervals.of("2000/2001"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new DimensionRangeShardSpec(ImmutableList.of("dim"), new StringTuple(new String[]{"foo"}), null, 1, 2), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that is alone in a time chunk. It is not served by any server, and such segments are assigned to the + * existing servers round-robin. Because this is the only "not served by any server" segment, it should + * be assigned to the first server. + */ + private static final DataSegment SEGMENT3 = new DataSegment( + DATASOURCE, + Intervals.of("2001/2002"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new NumberedShardSpec(0, 1), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that should be ignored because it's a tombstone. + */ + private static final DataSegment SEGMENT4 = new DataSegment( + DATASOURCE, + Intervals.of("2002/2003"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + TombstoneShardSpec.INSTANCE, + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Segment that should be ignored (for now) because it's realtime-only. + */ + private static final DataSegment SEGMENT5 = new DataSegment( + DATASOURCE, + Intervals.of("2003/2004"), + "1", + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList(), + new NumberedShardSpec(0, 1), + null, + null, + BYTES_PER_SEGMENT + ); + + /** + * Mapping of segment to servers (indexes in {@link #SERVERS}). + */ + private static final Map SEGMENT_SERVERS = + ImmutableMap.builder() + .put(SEGMENT1, IntList.of(0)) + .put(SEGMENT2, IntList.of(1)) + .put(SEGMENT3, IntLists.emptyList()) + .put(SEGMENT4, IntList.of(1)) + .put(SEGMENT5, IntList.of(2)) + .build(); + + private AutoCloseable mockCloser; + + /** + * Slicer under test. 
Built using {@link #timeline} and {@link #SERVERS}. + */ + private DartTableInputSpecSlicer slicer; + + /** + * Timeline built from {@link #SEGMENT_SERVERS} and {@link #SERVERS}. + */ + private VersionedIntervalTimeline timeline; + + /** + * Server view that uses {@link #timeline}. + */ + @Mock + private TimelineServerView serverView; + + @BeforeEach + void setUp() + { + mockCloser = MockitoAnnotations.openMocks(this); + slicer = DartTableInputSpecSlicer.createFromWorkerIds(WORKER_IDS, serverView); + + // Add all segments to the timeline, each served by the servers listed for it in SEGMENT_SERVERS. + timeline = new VersionedIntervalTimeline<>(Ordering.natural()); + for (Map.Entry entry : SEGMENT_SERVERS.entrySet()) { + final DataSegment dataSegment = entry.getKey(); + final IntList segmentServers = entry.getValue(); + final ServerSelector serverSelector = new ServerSelector( + dataSegment, + new HighestPriorityTierSelectorStrategy(new RandomServerSelectorStrategy()) + ); + for (int serverNumber : segmentServers) { + final DruidServerMetadata serverMetadata = SERVERS.get(serverNumber); + final DruidServer server = new DruidServer( + serverMetadata.getName(), + serverMetadata.getHostAndPort(), + serverMetadata.getHostAndTlsPort(), + serverMetadata.getMaxSize(), + serverMetadata.getType(), + serverMetadata.getTier(), + serverMetadata.getPriority() + ); + serverSelector.addServerAndUpdateSegment(new QueryableDruidServer<>(server, null), dataSegment); + } + timeline.add( + dataSegment.getInterval(), + dataSegment.getVersion(), + dataSegment.getShardSpec().createChunk(serverSelector) + ); + } + + Mockito.when(serverView.getDruidServerMetadatas()).thenReturn(SERVERS); + Mockito.when(serverView.getTimeline(new TableDataSource(DATASOURCE).getAnalysis())) + .thenReturn(Optional.of(timeline)); + Mockito.when(serverView.getTimeline(new TableDataSource(DATASOURCE_NONEXISTENT).getAnalysis())) + .thenReturn(Optional.empty()); + } + + @AfterEach + void tearDown() throws Exception + { + mockCloser.close(); 
+ } + + @Test + public void test_sliceDynamic() + { + // This slicer cannot sliceDynamic. + + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + Assertions.assertFalse(slicer.canSliceDynamic(inputSpec)); + Assertions.assertThrows( + UnsupportedOperationException.class, + () -> slicer.sliceDynamic(inputSpec, 1, 1, 1) + ); + } + + @Test + public void test_sliceStatic_wholeTable_oneSlice() + { + // When 1 slice is requested, all segments are assigned to one server, even if that server doesn't actually + // currently serve those segments. + + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 1); + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ) + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_wholeTable_twoSlices() + { + // When 2 slices are requested, we assign segments to the servers that have those segments. 
+ + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 2); + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ) + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_wholeTable_threeSlices() + { + // When 3 slices are requested, only 2 are nonempty, because segments are assigned to only two servers; the third slice is nil. 
+ + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 3); + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + NilInputSlice.INSTANCE + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_nonexistentTable() + { + final TableInputSpec inputSpec = new TableInputSpec(DATASOURCE_NONEXISTENT, null, null, null); + final List inputSlices = slicer.sliceStatic(inputSpec, 1); + Assertions.assertEquals( + Collections.emptyList(), + inputSlices + ); + } + + @Test + public void test_sliceStatic_dimensionFilter_twoSlices() + { + // Filtered on a dimension that is used for range partitioning in 2000/2001, so one segment gets pruned out. 
+ + final TableInputSpec inputSpec = new TableInputSpec( + DATASOURCE, + null, + new EqualityFilter(PARTITION_DIM, ColumnType.STRING, "abc", null), + null + ); + + final List inputSlices = slicer.sliceStatic(inputSpec, 2); + + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ), + new RichSegmentDescriptor( + SEGMENT3.getInterval(), + SEGMENT3.getInterval(), + SEGMENT3.getVersion(), + SEGMENT3.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + NilInputSlice.INSTANCE + ), + inputSlices + ); + } + + @Test + public void test_sliceStatic_timeFilter_twoSlices() + { + // Filtered on 2000/2001, so other segments get pruned out. + + final TableInputSpec inputSpec = new TableInputSpec( + DATASOURCE, + Collections.singletonList(Intervals.of("2000/P1Y")), + null, + null + ); + + final List inputSlices = slicer.sliceStatic(inputSpec, 2); + + Assertions.assertEquals( + ImmutableList.of( + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT1.getInterval(), + SEGMENT1.getInterval(), + SEGMENT1.getVersion(), + SEGMENT1.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ), + new SegmentsInputSlice( + DATASOURCE, + ImmutableList.of( + new RichSegmentDescriptor( + SEGMENT2.getInterval(), + SEGMENT2.getInterval(), + SEGMENT2.getVersion(), + SEGMENT2.getShardSpec().getPartitionNum() + ) + ), + ImmutableList.of() + ) + ), + inputSlices + ); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java new file mode 100644 index 000000000000..f4441c984e70 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/DartWorkerManagerTest.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.dart.controller; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.ints.IntSet; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.error.DruidException; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.msq.dart.worker.DartWorkerClient; +import org.apache.druid.msq.dart.worker.WorkerId; +import org.apache.druid.msq.exec.WorkerManager; +import org.apache.druid.msq.exec.WorkerStats; +import org.junit.Assert; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import java.util.Collections; +import java.util.List; +import 
java.util.Map; + +public class DartWorkerManagerTest +{ + private static final List WORKERS = ImmutableList.of( + new WorkerId("http", "localhost:1001", "abc").toString(), + new WorkerId("http", "localhost:1002", "abc").toString() + ); + + private DartWorkerManager workerManager; + private AutoCloseable mockCloser; + + @Mock + private DartWorkerClient workerClient; + + @BeforeEach + public void setUp() + { + mockCloser = MockitoAnnotations.openMocks(this); + workerManager = new DartWorkerManager(WORKERS, workerClient); + } + + @AfterEach + public void tearDown() throws Exception + { + mockCloser.close(); + } + + @Test + public void test_getWorkerCount() + { + Assertions.assertEquals(0, workerManager.getWorkerCount().getPendingWorkerCount()); + Assertions.assertEquals(2, workerManager.getWorkerCount().getRunningWorkerCount()); + } + + @Test + public void test_getWorkerIds() + { + Assertions.assertEquals(WORKERS, workerManager.getWorkerIds()); + } + + @Test + public void test_getWorkerStats() + { + final Map> stats = workerManager.getWorkerStats(); + Assertions.assertEquals( + ImmutableMap.of( + 0, Collections.singletonList(new WorkerStats(WORKERS.get(0), TaskState.RUNNING, -1, -1)), + 1, Collections.singletonList(new WorkerStats(WORKERS.get(1), TaskState.RUNNING, -1, -1)) + ), + stats + ); + } + + @Test + public void test_getWorkerNumber() + { + Assertions.assertEquals(0, workerManager.getWorkerNumber(WORKERS.get(0))); + Assertions.assertEquals(1, workerManager.getWorkerNumber(WORKERS.get(1))); + Assertions.assertEquals(WorkerManager.UNKNOWN_WORKER_NUMBER, workerManager.getWorkerNumber("nonexistent")); + } + + @Test + public void test_isWorkerActive() + { + Assertions.assertTrue(workerManager.isWorkerActive(WORKERS.get(0))); + Assertions.assertTrue(workerManager.isWorkerActive(WORKERS.get(1))); + Assertions.assertFalse(workerManager.isWorkerActive("nonexistent")); + } + + @Test + public void test_launchWorkersIfNeeded() + { + workerManager.launchWorkersIfNeeded(0); 
// Does nothing, less than WORKERS.size() + workerManager.launchWorkersIfNeeded(1); // Does nothing, less than WORKERS.size() + workerManager.launchWorkersIfNeeded(2); // Does nothing, equal to WORKERS.size() + Assert.assertThrows( + DruidException.class, + () -> workerManager.launchWorkersIfNeeded(3) + ); + } + + @Test + public void test_waitForWorkers() + { + workerManager.launchWorkersIfNeeded(2); + workerManager.waitForWorkers(IntSet.of(0, 1)); // Returns immediately + } + + @Test + public void test_start_stop_noInterrupt() + { + Mockito.when(workerClient.stopWorker(WORKERS.get(0))) + .thenReturn(Futures.immediateFuture(null)); + Mockito.when(workerClient.stopWorker(WORKERS.get(1))) + .thenReturn(Futures.immediateFuture(null)); + + final ListenableFuture future = workerManager.start(); + workerManager.stop(false); + + // Ensure the future from start() resolves. + Assertions.assertNull(FutureUtils.getUnchecked(future, true)); + } + + @Test + public void test_start_stop_interrupt() + { + Mockito.when(workerClient.stopWorker(WORKERS.get(0))) + .thenReturn(Futures.immediateFuture(null)); + Mockito.when(workerClient.stopWorker(WORKERS.get(1))) + .thenReturn(Futures.immediateFuture(null)); + + final ListenableFuture future = workerManager.start(); + workerManager.stop(true); + + // Ensure the future from start() resolves. + Assertions.assertNull(FutureUtils.getUnchecked(future, true)); + } + + @Test + public void test_start_stop_interrupt_clientError() + { + Mockito.when(workerClient.stopWorker(WORKERS.get(0))) + .thenReturn(Futures.immediateFailedFuture(new ISE("stop failure"))); + Mockito.when(workerClient.stopWorker(WORKERS.get(1))) + .thenReturn(Futures.immediateFuture(null)); + + final ListenableFuture future = workerManager.start(); + workerManager.stop(true); + + // Ensure the future from start() resolves. 
+ Assertions.assertNull(FutureUtils.getUnchecked(future, true)); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartQueryInfoTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartQueryInfoTest.java new file mode 100644 index 000000000000..980038723532 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartQueryInfoTest.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.http; + +import nl.jqno.equalsverifier.EqualsVerifier; +import org.junit.jupiter.api.Test; + +public class DartQueryInfoTest +{ + @Test + public void test_equals() + { + EqualsVerifier.forClass(DartQueryInfo.class).usingGetClass().verify(); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java new file mode 100644 index 000000000000..db3479178724 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/DartSqlResourceTest.java @@ -0,0 +1,757 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.msq.dart.controller.http; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.util.concurrent.Futures; +import org.apache.druid.indexer.report.TaskReport; +import org.apache.druid.indexing.common.TaskLockType; +import org.apache.druid.java.util.common.DateTimes; +import org.apache.druid.java.util.common.IAE; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.msq.dart.controller.ControllerHolder; +import org.apache.druid.msq.dart.controller.DartControllerRegistry; +import org.apache.druid.msq.dart.controller.sql.DartQueryMaker; +import org.apache.druid.msq.dart.controller.sql.DartSqlClient; +import org.apache.druid.msq.dart.controller.sql.DartSqlClients; +import org.apache.druid.msq.dart.controller.sql.DartSqlEngine; +import org.apache.druid.msq.dart.guice.DartControllerConfig; +import org.apache.druid.msq.exec.Controller; +import org.apache.druid.msq.indexing.error.CanceledFault; +import org.apache.druid.msq.indexing.error.InvalidNullByteFault; +import org.apache.druid.msq.indexing.error.MSQErrorReport; +import org.apache.druid.msq.indexing.error.MSQFaultUtils; +import org.apache.druid.msq.indexing.report.MSQTaskReport; +import org.apache.druid.msq.test.MSQTestBase; +import org.apache.druid.msq.test.MSQTestControllerContext; +import org.apache.druid.query.DefaultQueryConfig; +import org.apache.druid.query.QueryContext; +import org.apache.druid.query.QueryContexts; +import org.apache.druid.server.DruidNode; +import org.apache.druid.server.QueryStackTests; +import org.apache.druid.server.ResponseContextConfig; +import 
org.apache.druid.server.initialization.ServerConfig; +import org.apache.druid.server.log.NoopRequestLogger; +import org.apache.druid.server.metrics.NoopServiceEmitter; +import org.apache.druid.server.mocks.MockAsyncContext; +import org.apache.druid.server.mocks.MockHttpServletResponse; +import org.apache.druid.server.security.AuthConfig; +import org.apache.druid.server.security.AuthenticationResult; +import org.apache.druid.server.security.ForbiddenException; +import org.apache.druid.sql.SqlLifecycleManager; +import org.apache.druid.sql.SqlStatementFactory; +import org.apache.druid.sql.SqlToolbox; +import org.apache.druid.sql.calcite.planner.CalciteRulesManager; +import org.apache.druid.sql.calcite.planner.CatalogResolver; +import org.apache.druid.sql.calcite.planner.PlannerConfig; +import org.apache.druid.sql.calcite.planner.PlannerFactory; +import org.apache.druid.sql.calcite.schema.DruidSchemaCatalog; +import org.apache.druid.sql.calcite.schema.NoopDruidSchemaManager; +import org.apache.druid.sql.calcite.util.CalciteTests; +import org.apache.druid.sql.calcite.util.QueryFrameworkUtils; +import org.apache.druid.sql.calcite.view.NoopViewManager; +import org.apache.druid.sql.hook.DruidHookDispatcher; +import org.apache.druid.sql.http.ResultFormat; +import org.apache.druid.sql.http.SqlQuery; +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Response; +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import 
static org.hamcrest.MatcherAssert.assertThat; + +/** + * Functional test of {@link DartSqlResource}, {@link DartSqlEngine}, and {@link DartQueryMaker}. + * Other classes are mocked when possible. + */ +public class DartSqlResourceTest extends MSQTestBase +{ + private static final DruidNode SELF_NODE = new DruidNode("none", "localhost", false, 8080, -1, true, false); + private static final String AUTHENTICATOR_NAME = "authn"; + private static final int MAX_CONTROLLERS = 1; + + /** + * A user that is not a superuser. + * See {@link CalciteTests#TEST_AUTHORIZER_MAPPER} for how this user is mapped. + */ + private static final String REGULAR_USER_NAME = "regularUser"; + + /** + * A user that is not a superuser, and is different from {@link #REGULAR_USER_NAME}. + * See {@link CalciteTests#TEST_AUTHORIZER_MAPPER} for how this user is mapped. + */ + private static final String DIFFERENT_REGULAR_USER_NAME = "differentRegularUser"; + + /** + * Latch that cancellation tests can use to determine when a query is added to the {@link DartControllerRegistry}, + * and becomes cancelable. + */ + private final CountDownLatch controllerRegistered = new CountDownLatch(1); + + // Objects created in setUp() below this line. + + private DartSqlResource sqlResource; + private DartControllerRegistry controllerRegistry; + private ExecutorService controllerExecutor; + private AutoCloseable mockCloser; + + // Mocks below this line. + + /** + * Mock for {@link DartSqlClients}, which is used in tests of {@link DartSqlResource#doGetRunningQueries}. + */ + @Mock + private DartSqlClients dartSqlClients; + + /** + * Mock for {@link DartSqlClient}, which is used in tests of {@link DartSqlResource#doGetRunningQueries}. + */ + @Mock + private DartSqlClient dartSqlClient; + + /** + * Mock http request. + */ + @Mock + private HttpServletRequest httpServletRequest; + + /** + * Mock for test cases that need to make two requests. 
+ */ + @Mock + private HttpServletRequest httpServletRequest2; + + @BeforeEach + void setUp() + { + mockCloser = MockitoAnnotations.openMocks(this); + + final DartSqlEngine engine = new DartSqlEngine( + queryId -> new MSQTestControllerContext( + objectMapper, + injector, + null /* not used in this test */, + workerMemoryParameters, + loadedSegmentsMetadata, + TaskLockType.APPEND, + QueryContext.empty() + ), + controllerRegistry = new DartControllerRegistry() + { + @Override + public void register(ControllerHolder holder) + { + super.register(holder); + controllerRegistered.countDown(); + } + }, + objectMapper.convertValue(ImmutableMap.of(), DartControllerConfig.class), + controllerExecutor = Execs.multiThreaded( + MAX_CONTROLLERS, + StringUtils.encodeForFormat(getClass().getSimpleName() + "-controller-exec") + ) + ); + + final DruidSchemaCatalog rootSchema = QueryFrameworkUtils.createMockRootSchema( + CalciteTests.INJECTOR, + queryFramework().conglomerate(), + queryFramework().walker(), + new PlannerConfig(), + new NoopViewManager(), + new NoopDruidSchemaManager(), + CalciteTests.TEST_AUTHORIZER_MAPPER, + CatalogResolver.NULL_RESOLVER + ); + + final PlannerFactory plannerFactory = new PlannerFactory( + rootSchema, + queryFramework().operatorTable(), + queryFramework().macroTable(), + PLANNER_CONFIG_DEFAULT, + CalciteTests.TEST_AUTHORIZER_MAPPER, + objectMapper, + CalciteTests.DRUID_SCHEMA_NAME, + new CalciteRulesManager(ImmutableSet.of()), + CalciteTests.createJoinableFactoryWrapper(), + CatalogResolver.NULL_RESOLVER, + new AuthConfig(), + new DruidHookDispatcher() + ); + + final SqlLifecycleManager lifecycleManager = new SqlLifecycleManager(); + final SqlToolbox toolbox = new SqlToolbox( + engine, + plannerFactory, + new NoopServiceEmitter(), + new NoopRequestLogger(), + QueryStackTests.DEFAULT_NOOP_SCHEDULER, + new DefaultQueryConfig(ImmutableMap.of()), + lifecycleManager + ); + + sqlResource = new DartSqlResource( + objectMapper, + 
CalciteTests.TEST_AUTHORIZER_MAPPER, + new SqlStatementFactory(toolbox), + controllerRegistry, + lifecycleManager, + dartSqlClients, + new ServerConfig() /* currently only used for error transform strategy */, + ResponseContextConfig.newConfig(false), + SELF_NODE, + new DefaultQueryConfig(ImmutableMap.of("foo", "bar")) + ); + + // Setup mocks + Mockito.when(dartSqlClients.getAllClients()).thenReturn(Collections.singletonList(dartSqlClient)); + } + + @AfterEach + void tearDown() throws Exception + { + mockCloser.close(); + + // shutdown(), not shutdownNow(), to ensure controllers stop timely on their own. + controllerExecutor.shutdown(); + + if (!controllerExecutor.awaitTermination(1, TimeUnit.MINUTES)) { + throw new IAE("controllerExecutor.awaitTermination() timed out"); + } + + // Ensure that controllerRegistry has nothing in it at the conclusion of each test. Verifies that controllers + // are fully cleaned up. + Assertions.assertEquals(0, controllerRegistry.getAllHolders().size(), "controllerRegistry.getAllHolders().size()"); + } + + @Test + public void test_getEnabled() + { + final Response response = sqlResource.doGetEnabled(httpServletRequest); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + } + + /** + * Test where a superuser calls {@link DartSqlResource#doGetRunningQueries} with selfOnly enabled. 
+ */ + @Test + public void test_getRunningQueries_selfOnly_superUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(CalciteTests.TEST_SUPERUSER_NAME)); + + final ControllerHolder holder = setUpMockRunningQuery(REGULAR_USER_NAME); + + Assertions.assertEquals( + new GetQueriesResponse(Collections.singletonList(DartQueryInfo.fromControllerHolder(holder))), + sqlResource.doGetRunningQueries("", httpServletRequest) + ); + + controllerRegistry.deregister(holder); + } + + /** + * Test where {@link #REGULAR_USER_NAME} and {@link #DIFFERENT_REGULAR_USER_NAME} issue queries, and + * {@link #REGULAR_USER_NAME} calls {@link DartSqlResource#doGetRunningQueries} with selfOnly enabled. + */ + @Test + public void test_getRunningQueries_selfOnly_regularUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + final ControllerHolder holder = setUpMockRunningQuery(REGULAR_USER_NAME); + final ControllerHolder holder2 = setUpMockRunningQuery(DIFFERENT_REGULAR_USER_NAME); + + // Regular users can see only their own queries, without authentication details. + Assertions.assertEquals(2, controllerRegistry.getAllHolders().size()); + Assertions.assertEquals( + new GetQueriesResponse( + Collections.singletonList(DartQueryInfo.fromControllerHolder(holder).withoutAuthenticationResult())), + sqlResource.doGetRunningQueries("", httpServletRequest) + ); + + controllerRegistry.deregister(holder); + controllerRegistry.deregister(holder2); + } + + /** + * Test where a superuser calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled. 
+ */ + @Test + public void test_getRunningQueries_global_superUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(CalciteTests.TEST_SUPERUSER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder localHolder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // DIFFERENT_REGULAR_USER_NAME runs a query remotely. + final DartQueryInfo remoteQueryInfo = new DartQueryInfo( + "sid", + "did2", + "SELECT 2", + AUTHENTICATOR_NAME, + DIFFERENT_REGULAR_USER_NAME, + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ); + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFuture(new GetQueriesResponse(Collections.singletonList(remoteQueryInfo)))); + + // With selfOnly = null, the endpoint returns both queries. + Assertions.assertEquals( + new GetQueriesResponse( + ImmutableList.of( + DartQueryInfo.fromControllerHolder(localHolder), + remoteQueryInfo + ) + ), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(localHolder); + } + + /** + * Test where a superuser calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled, and where the + * remote server has a problem. + */ + @Test + public void test_getRunningQueries_global_remoteError_superUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(CalciteTests.TEST_SUPERUSER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder localHolder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // Remote call fails. + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFailedFuture(new IOException("something went wrong"))); + + // We only see local queries, because the remote call failed. (The entire call doesn't fail; we see what we + // were able to fetch.) 
+ Assertions.assertEquals( + new GetQueriesResponse(ImmutableList.of(DartQueryInfo.fromControllerHolder(localHolder))), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(localHolder); + } + + /** + * Test where {@link #REGULAR_USER_NAME} and {@link #DIFFERENT_REGULAR_USER_NAME} issue queries, and + * {@link #REGULAR_USER_NAME} calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled. + */ + @Test + public void test_getRunningQueries_global_regularUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder localHolder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // DIFFERENT_REGULAR_USER_NAME runs a query remotely. + final DartQueryInfo remoteQueryInfo = new DartQueryInfo( + "sid", + "did2", + "SELECT 2", + AUTHENTICATOR_NAME, + DIFFERENT_REGULAR_USER_NAME, + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ); + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFuture(new GetQueriesResponse(Collections.singletonList(remoteQueryInfo)))); + + // The endpoint returns only the query issued by REGULAR_USER_NAME. + Assertions.assertEquals( + new GetQueriesResponse( + ImmutableList.of(DartQueryInfo.fromControllerHolder(localHolder).withoutAuthenticationResult())), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(localHolder); + } + + /** + * Test where {@link #REGULAR_USER_NAME} and {@link #DIFFERENT_REGULAR_USER_NAME} issue queries, and + * {@link #DIFFERENT_REGULAR_USER_NAME} calls {@link DartSqlResource#doGetRunningQueries} with selfOnly disabled. 
+ */ + @Test + public void test_getRunningQueries_global_differentRegularUser() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(DIFFERENT_REGULAR_USER_NAME)); + + // REGULAR_USER_NAME runs a query locally. + final ControllerHolder holder = setUpMockRunningQuery(REGULAR_USER_NAME); + + // DIFFERENT_REGULAR_USER_NAME runs a query remotely. + final DartQueryInfo remoteQueryInfo = new DartQueryInfo( + "sid", + "did2", + "SELECT 2", + AUTHENTICATOR_NAME, + DIFFERENT_REGULAR_USER_NAME, + DateTimes.of("2000"), + ControllerHolder.State.RUNNING.toString() + ); + Mockito.when(dartSqlClient.getRunningQueries(true)) + .thenReturn(Futures.immediateFuture(new GetQueriesResponse(Collections.singletonList(remoteQueryInfo)))); + + // The endpoint returns only the query issued by DIFFERENT_REGULAR_USER_NAME. + Assertions.assertEquals( + new GetQueriesResponse(ImmutableList.of(remoteQueryInfo.withoutAuthenticationResult())), + sqlResource.doGetRunningQueries(null, httpServletRequest) + ); + + controllerRegistry.deregister(holder); + } + + @Test + public void test_doPost_regularUser() + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + "SELECT 1 + 1", + ResultFormat.ARRAY, + false, + false, + false, + Collections.emptyMap(), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), asyncResponse.getStatus()); + Assertions.assertEquals("[[2]]\n", 
StringUtils.fromUtf8(asyncResponse.baos.toByteArray())); + } + + @Test + public void test_doPost_regularUser_forbidden() + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + StringUtils.format("SELECT * FROM \"%s\"", CalciteTests.FORBIDDEN_DATASOURCE), + ResultFormat.ARRAY, + false, + false, + false, + Collections.emptyMap(), + Collections.emptyList() + ); + + Assertions.assertThrows( + ForbiddenException.class, + () -> sqlResource.doPost(sqlQuery, httpServletRequest) + ); + } + + @Test + public void test_doPost_regularUser_runtimeError() throws IOException + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + "SELECT U&'\\0000'", + ResultFormat.ARRAY, + false, + false, + false, + Collections.emptyMap(), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), asyncResponse.getStatus()); + + final Map e = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT + ); + + Assertions.assertEquals("InvalidNullByte", e.get("errorCode")); + Assertions.assertEquals("RUNTIME_FAILURE", 
e.get("category")); + assertThat((String) e.get("errorMessage"), CoreMatchers.startsWith("InvalidNullByte: ")); + } + + @Test + public void test_doPost_regularUser_fullReport() throws Exception + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final SqlQuery sqlQuery = new SqlQuery( + "SELECT 1 + 1", + ResultFormat.ARRAY, + false, + false, + false, + ImmutableMap.of(DartSqlEngine.CTX_FULL_REPORT, true), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), asyncResponse.getStatus()); + + final List> reportMaps = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + new TypeReference>>() {} + ); + + Assertions.assertEquals(1, reportMaps.size()); + final MSQTaskReport report = + (MSQTaskReport) Iterables.getOnlyElement(Iterables.getOnlyElement(reportMaps)).get(MSQTaskReport.REPORT_KEY); + final List results = report.getPayload().getResults().getResults(); + + Assertions.assertEquals(1, results.size()); + Assertions.assertArrayEquals(new Object[]{2}, results.get(0)); + } + + @Test + public void test_doPost_regularUser_runtimeError_fullReport() throws Exception + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + final 
SqlQuery sqlQuery = new SqlQuery( + "SELECT U&'\\0000'", + ResultFormat.ARRAY, + false, + false, + false, + ImmutableMap.of(DartSqlEngine.CTX_FULL_REPORT, true), + Collections.emptyList() + ); + + Assertions.assertNull(sqlResource.doPost(sqlQuery, httpServletRequest)); + Assertions.assertEquals(Response.Status.OK.getStatusCode(), asyncResponse.getStatus()); + + final List> reportMaps = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + new TypeReference>>() {} + ); + + Assertions.assertEquals(1, reportMaps.size()); + final MSQTaskReport report = + (MSQTaskReport) Iterables.getOnlyElement(Iterables.getOnlyElement(reportMaps)).get(MSQTaskReport.REPORT_KEY); + final MSQErrorReport errorReport = report.getPayload().getStatus().getErrorReport(); + Assertions.assertNotNull(errorReport); + assertThat(errorReport.getFault(), CoreMatchers.instanceOf(InvalidNullByteFault.class)); + } + + @Test + public void test_doPost_regularUser_thenCancelQuery() throws Exception + { + run_test_doPost_regularUser_fullReport_thenCancelQuery(false); + } + + @Test + public void test_doPost_regularUser_fullReport_thenCancelQuery() throws Exception + { + run_test_doPost_regularUser_fullReport_thenCancelQuery(true); + } + + /** + * Helper for {@link #test_doPost_regularUser_thenCancelQuery()} and + * {@link #test_doPost_regularUser_fullReport_thenCancelQuery()}. We need to do cancellation tests with and + * without the "fullReport" parameter, because {@link DartQueryMaker} has a separate pathway for each one. + */ + private void run_test_doPost_regularUser_fullReport_thenCancelQuery(final boolean fullReport) throws Exception + { + final MockAsyncContext asyncContext = new MockAsyncContext(); + final MockHttpServletResponse asyncResponse = new MockHttpServletResponse(); + asyncContext.response = asyncResponse; + + // POST SQL query request. 
+ Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + Mockito.when(httpServletRequest.startAsync()) + .thenReturn(asyncContext); + + // Cancellation request. + Mockito.when(httpServletRequest2.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + // Block up the controllerExecutor so the controller runs long enough to cancel it. + final Future sleepFuture = controllerExecutor.submit(() -> { + try { + Thread.sleep(3_600_000); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + final String sqlQueryId = UUID.randomUUID().toString(); + final SqlQuery sqlQuery = new SqlQuery( + "SELECT 1 + 1", + ResultFormat.ARRAY, + false, + false, + false, + ImmutableMap.of(QueryContexts.CTX_SQL_QUERY_ID, sqlQueryId, DartSqlEngine.CTX_FULL_REPORT, fullReport), + Collections.emptyList() + ); + + final ExecutorService doPostExec = Execs.singleThreaded("do-post-exec-%s"); + final Future doPostFuture; + try { + // Run doPost in a separate thread. There are now three threads: + // 1) The controllerExecutor thread, which is blocked up by sleepFuture. + // 2) The doPostExec thread, which has a doPost in there, blocking on controllerExecutor. + // 3) The current main test thread, which continues on and which will issue the cancellation request. + doPostFuture = doPostExec.submit(() -> sqlResource.doPost(sqlQuery, httpServletRequest)); + controllerRegistered.await(); + + // Issue cancellation request. + final Response cancellationResponse = sqlResource.cancelQuery(sqlQueryId, httpServletRequest2); + Assertions.assertEquals(Response.Status.ACCEPTED.getStatusCode(), cancellationResponse.getStatus()); + + // Now that the cancellation request has been accepted, we can cancel the sleepFuture and allow the + // controller to be canceled. 
+ sleepFuture.cancel(true); + doPostExec.shutdown(); + } + catch (Throwable e) { + doPostExec.shutdownNow(); + throw e; + } + + if (!doPostExec.awaitTermination(1, TimeUnit.MINUTES)) { + throw new ISE("doPost timed out"); + } + + // Wait for the SQL POST to come back. + Assertions.assertNull(doPostFuture.get()); + Assertions.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), asyncResponse.getStatus()); + + // Ensure MSQ fault (CanceledFault) is properly translated to a DruidException and then properly serialized. + final Map e = objectMapper.readValue( + asyncResponse.baos.toByteArray(), + JacksonUtils.TYPE_REFERENCE_MAP_STRING_OBJECT + ); + Assertions.assertEquals("Canceled", e.get("errorCode")); + Assertions.assertEquals("CANCELED", e.get("category")); + Assertions.assertEquals( + MSQFaultUtils.generateMessageWithErrorCode(CanceledFault.instance()), + e.get("errorMessage") + ); + } + + @Test + public void test_cancelQuery_regularUser_unknownQuery() + { + Mockito.when(httpServletRequest.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .thenReturn(makeAuthenticationResult(REGULAR_USER_NAME)); + + final Response cancellationResponse = sqlResource.cancelQuery("nonexistent", httpServletRequest); + Assertions.assertEquals(Response.Status.NOT_FOUND.getStatusCode(), cancellationResponse.getStatus()); + } + + /** + * Add a mock {@link ControllerHolder} to {@link #controllerRegistry}, with a query run by the given user. + * Used by methods that test {@link DartSqlResource#doGetRunningQueries}. 
+ * + * @return the mock holder + */ + private ControllerHolder setUpMockRunningQuery(final String identity) + { + final Controller controller = Mockito.mock(Controller.class); + Mockito.when(controller.queryId()).thenReturn("did_" + identity); + + final AuthenticationResult authenticationResult = makeAuthenticationResult(identity); + final ControllerHolder holder = + new ControllerHolder(controller, null, "sid", "SELECT 1", authenticationResult, DateTimes.of("2000")); + + controllerRegistry.register(holder); + return holder; + } + + /** + * Create an {@link AuthenticationResult} with {@link AuthenticationResult#getAuthenticatedBy()} set to + * {@link #AUTHENTICATOR_NAME}. + */ + private static AuthenticationResult makeAuthenticationResult(final String identity) + { + return new AuthenticationResult(identity, null, AUTHENTICATOR_NAME, Collections.emptyMap()); + } +} diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java new file mode 100644 index 000000000000..7b43c863c9d1 --- /dev/null +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/http/GetQueriesResponseTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.http;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.msq.dart.controller.ControllerHolder;
+import org.apache.druid.segment.TestHelper;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+
+/**
+ * Tests JSON round-tripping and the equals contract of {@link GetQueriesResponse}.
+ */
+public class GetQueriesResponseTest
+{
+  /**
+   * Serializes a response holding a single {@link DartQueryInfo} with {@link TestHelper#JSON_MAPPER}, reads it back,
+   * and expects the result to equal the original.
+   */
+  @Test
+  public void test_serde() throws Exception
+  {
+    final DartQueryInfo queryInfo = new DartQueryInfo(
+        "xyz",
+        "abc",
+        "SELECT 1",
+        "auth",
+        "anon",
+        DateTimes.of("2000"),
+        ControllerHolder.State.RUNNING.toString()
+    );
+    final GetQueriesResponse original = new GetQueriesResponse(Collections.singletonList(queryInfo));
+
+    final ObjectMapper mapper = TestHelper.JSON_MAPPER;
+    final byte[] serialized = mapper.writeValueAsBytes(original);
+    final GetQueriesResponse roundTripped = mapper.readValue(serialized, GetQueriesResponse.class);
+
+    Assertions.assertEquals(original, roundTripped);
+  }
+
+  /**
+   * Runs EqualsVerifier against {@link GetQueriesResponse}, configured with {@code usingGetClass()}.
+   */
+  @Test
+  public void test_equals()
+  {
+    EqualsVerifier.forClass(GetQueriesResponse.class)
+                  .usingGetClass()
+                  .verify();
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/messages/ControllerMessageTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/messages/ControllerMessageTest.java
new file mode 100644
index 000000000000..427faf4aee6f
--- /dev/null
+++
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/messages/ControllerMessageTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.messages;
+
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.msq.guice.MSQIndexingModule;
+import org.apache.druid.msq.indexing.error.MSQErrorReport;
+import org.apache.druid.msq.indexing.error.UnknownFault;
+import org.apache.druid.msq.kernel.StageId;
+import org.apache.druid.msq.statistics.PartialKeyStatisticsInformation;
+import org.apache.druid.segment.TestHelper;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.util.Collections;
+
+/**
+ * Tests JSON round-tripping and the equals/hashCode contract of the {@link ControllerMessage} implementations
+ * {@link PartialKeyStatistics}, {@link DoneReadingInput}, {@link ResultsComplete}, {@link WorkerError}, and
+ * {@link WorkerWarning}.
+ */
+public class ControllerMessageTest
+{
+  private static final StageId STAGE_ID = StageId.fromString("xyz_2");
+  private ObjectMapper objectMapper;
+
+  /**
+   * Builds a fresh mapper for each test: a copy of {@link TestHelper#JSON_MAPPER} with strict duplicate detection
+   * enabled and the Jackson modules from {@link MSQIndexingModule} registered.
+   */
+  @BeforeEach
+  public void setUp()
+  {
+    final ObjectMapper mapper = TestHelper.JSON_MAPPER.copy();
+    mapper.enable(JsonParser.Feature.STRICT_DUPLICATE_DETECTION);
+    mapper.registerModules(new MSQIndexingModule().getJacksonModules());
+    objectMapper = mapper;
+  }
+
+  /**
+   * Round-trips one instance of each message type through JSON.
+   */
+  @Test
+  public void testSerde() throws IOException
+  {
+    final PartialKeyStatisticsInformation emptyStatistics =
+        new PartialKeyStatisticsInformation(Collections.emptySet(), false, 0);
+
+    assertSerde(new PartialKeyStatistics(STAGE_ID, 1, emptyStatistics));
+    assertSerde(new DoneReadingInput(STAGE_ID, 1));
+    assertSerde(new ResultsComplete(STAGE_ID, 1, "foo"));
+    assertSerde(new WorkerError(STAGE_ID.getQueryId(), makeErrorReport()));
+    assertSerde(new WorkerWarning(STAGE_ID.getQueryId(), Collections.singletonList(makeErrorReport())));
+  }
+
+  /**
+   * Runs EqualsVerifier, configured with {@code usingGetClass()}, against each message type.
+   */
+  @Test
+  public void testEqualsAndHashCode()
+  {
+    EqualsVerifier.forClass(PartialKeyStatistics.class).usingGetClass().verify();
+    EqualsVerifier.forClass(DoneReadingInput.class).usingGetClass().verify();
+    EqualsVerifier.forClass(ResultsComplete.class).usingGetClass().verify();
+    EqualsVerifier.forClass(WorkerError.class).usingGetClass().verify();
+    EqualsVerifier.forClass(WorkerWarning.class).usingGetClass().verify();
+  }
+
+  /**
+   * Creates a new error report wrapping an {@link UnknownFault} with message "oops".
+   */
+  private static MSQErrorReport makeErrorReport()
+  {
+    return MSQErrorReport.fromFault("task", null, null, UnknownFault.forMessage("oops"));
+  }
+
+  /**
+   * Writes the message as a JSON string, reads it back as a {@link ControllerMessage}, and expects equality with the
+   * original. The JSON text is used as the assertion failure message.
+   */
+  private void assertSerde(final ControllerMessage message) throws IOException
+  {
+    final String serialized = objectMapper.writeValueAsString(message);
+    final ControllerMessage roundTripped = objectMapper.readValue(serialized, ControllerMessage.class);
+    Assertions.assertEquals(message, roundTripped, serialized);
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java
new file mode 100644
index 000000000000..19a4eaf0b151
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/controller/sql/DartSqlClientImplTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.controller.sql;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.util.concurrent.ListenableFuture;
+import org.apache.druid.jackson.DefaultObjectMapper;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.msq.dart.controller.ControllerHolder;
+import org.apache.druid.msq.dart.controller.http.DartQueryInfo;
+import org.apache.druid.msq.dart.controller.http.GetQueriesResponse;
+import org.apache.druid.rpc.MockServiceClient;
+import org.apache.druid.rpc.RequestBuilder;
+import org.jboss.netty.handler.codec.http.HttpMethod;
+import org.jboss.netty.handler.codec.http.HttpResponseStatus;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import javax.ws.rs.core.HttpHeaders;
+import javax.ws.rs.core.MediaType;
+
+/**
+ * Tests {@link DartSqlClientImpl#getRunningQueries} against a {@link MockServiceClient}.
+ */
+public class DartSqlClientImplTest
+{
+  private ObjectMapper jsonMapper;
+  private MockServiceClient serviceClient;
+  private DartSqlClient dartSqlClient;
+
+  @BeforeEach
+  public void setup()
+  {
+    jsonMapper = new DefaultObjectMapper();
+    serviceClient = new MockServiceClient();
+    dartSqlClient = new DartSqlClientImpl(serviceClient, jsonMapper);
+  }
+
+  /**
+   * Calls {@link MockServiceClient#verify()} after every test.
+   */
+  @AfterEach
+  public void tearDown()
+  {
+    serviceClient.verify();
+  }
+
+  /**
+   * With selfOnly = false, the client is expected to issue {@code GET /} and return the deserialized response.
+   */
+  @Test
+  public void test_getRunningQueries_all() throws Exception
+  {
+    final GetQueriesResponse getQueriesResponse = makeGetQueriesResponse();
+
+    serviceClient.expectAndRespond(
+        new RequestBuilder(HttpMethod.GET, "/"),
+        HttpResponseStatus.OK,
+        ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON),
+        jsonMapper.writeValueAsBytes(getQueriesResponse)
+    );
+
+    final ListenableFuture<GetQueriesResponse> result = dartSqlClient.getRunningQueries(false);
+    Assertions.assertEquals(getQueriesResponse, result.get());
+  }
+
+  /**
+   * With selfOnly = true, the client is expected to issue {@code GET /?selfOnly} and return the deserialized
+   * response.
+   */
+  @Test
+  public void test_getRunningQueries_selfOnly() throws Exception
+  {
+    final GetQueriesResponse getQueriesResponse = makeGetQueriesResponse();
+
+    serviceClient.expectAndRespond(
+        new RequestBuilder(HttpMethod.GET, "/?selfOnly"),
+        HttpResponseStatus.OK,
+        ImmutableMap.of(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON),
+        jsonMapper.writeValueAsBytes(getQueriesResponse)
+    );
+
+    final ListenableFuture<GetQueriesResponse> result = dartSqlClient.getRunningQueries(true);
+    Assertions.assertEquals(getQueriesResponse, result.get());
+  }
+
+  /**
+   * Creates the response fixture shared by the tests: a single running query.
+   */
+  private static GetQueriesResponse makeGetQueriesResponse()
+  {
+    return new GetQueriesResponse(
+        ImmutableList.of(
+            new DartQueryInfo(
+                "sid",
+                "did",
+                "SELECT 1",
+                "",
+                "",
+                DateTimes.of("2000"),
+                ControllerHolder.State.RUNNING.toString()
+            )
+        )
+    );
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartQueryableSegmentTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartQueryableSegmentTest.java
new file mode 100644
index 000000000000..b53a397dae81
--- /dev/null
+++
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartQueryableSegmentTest.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.worker;
+
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests the equals contract of {@link DartQueryableSegment}.
+ */
+public class DartQueryableSegmentTest
+{
+  /**
+   * Runs EqualsVerifier against {@link DartQueryableSegment}, configured with {@code usingGetClass()}.
+   */
+  @Test
+  public void test_equals()
+  {
+    EqualsVerifier.forClass(DartQueryableSegment.class)
+                  .usingGetClass()
+                  .verify();
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartWorkerRunnerTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartWorkerRunnerTest.java
new file mode 100644
index 000000000000..1f152b74049f
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/DartWorkerRunnerTest.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.worker;
+
+import com.google.common.util.concurrent.SettableFuture;
+import org.apache.druid.discovery.DiscoveryDruidNode;
+import org.apache.druid.discovery.DruidNodeDiscovery;
+import org.apache.druid.discovery.DruidNodeDiscoveryProvider;
+import org.apache.druid.discovery.NodeRole;
+import org.apache.druid.error.DruidException;
+import org.apache.druid.java.util.common.FileUtils;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.msq.dart.DartResourcePermissionMapper;
+import org.apache.druid.msq.dart.worker.http.GetWorkersResponse;
+import org.apache.druid.msq.exec.Worker;
+import org.apache.druid.msq.indexing.error.CanceledFault;
+import org.apache.druid.msq.indexing.error.MSQException;
+import org.apache.druid.query.QueryContext;
+import org.apache.druid.server.DruidNode;
+import org.apache.druid.server.security.AuthorizerMapper;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
+import org.junit.internal.matchers.ThrowableMessageMatcher;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+import org.junit.jupiter.api.io.TempDir;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.Collections;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * Functional test of {@link DartWorkerRunner}.
+ *
+ * The runner is wired to a mock {@link DartWorkerFactory} that always builds the same mock {@link Worker}. The
+ * lifecycle of that worker is simulated with {@link #workerRun}: {@code run()} blocks on it, while {@code stop()}
+ * and {@code controllerFailed()} resolve it with an exception.
+ */
+public class DartWorkerRunnerTest
+{
+  private static final int MAX_WORKERS = 1;
+  private static final String QUERY_ID = "abc";
+  private static final WorkerId WORKER_ID = new WorkerId("http", "localhost:8282", QUERY_ID);
+  private static final String CONTROLLER_SERVER_HOST = "localhost:8081";
+  private static final DiscoveryDruidNode CONTROLLER_DISCOVERY_NODE =
+      new DiscoveryDruidNode(
+          new DruidNode("no", "localhost", false, 8081, -1, true, false),
+          NodeRole.BROKER,
+          Collections.emptyMap()
+      );
+
+  /**
+   * Future that the mock worker's {@code run()} and {@code awaitStop()} block on. Only ever resolved exceptionally
+   * by this test, so the value type is irrelevant.
+   */
+  private final SettableFuture<?> workerRun = SettableFuture.create();
+
+  private ExecutorService workerExec;
+  private DartWorkerRunner workerRunner;
+  private AutoCloseable mockCloser;
+
+  @TempDir
+  public Path temporaryFolder;
+
+  @Mock
+  private DartWorkerFactory workerFactory;
+
+  @Mock
+  private Worker worker;
+
+  @Mock
+  private DruidNodeDiscoveryProvider discoveryProvider;
+
+  @Mock
+  private DruidNodeDiscovery discovery;
+
+  @Mock
+  private AuthorizerMapper authorizerMapper;
+
+  /**
+   * Captures the listener that the runner registers with {@link #discovery} during {@link #setUp()}. Tests use it
+   * to simulate controllers (Brokers) appearing and disappearing.
+   */
+  @Captor
+  private ArgumentCaptor<DruidNodeDiscovery.Listener> discoveryListener;
+
+  @BeforeEach
+  public void setUp()
+  {
+    mockCloser = MockitoAnnotations.openMocks(this);
+    workerExec = Execs.multiThreaded(MAX_WORKERS, "worker-exec-%s");
+    workerRunner = new DartWorkerRunner(
+        workerFactory,
+        workerExec,
+        discoveryProvider,
+        new DartResourcePermissionMapper(),
+        authorizerMapper,
+        temporaryFolder.toFile()
+    );
+
+    // "discoveryProvider" provides "discovery".
+    Mockito.when(discoveryProvider.getForNodeRole(NodeRole.BROKER)).thenReturn(discovery);
+
+    // "workerFactory" builds "worker".
+    Mockito.when(
+        workerFactory.build(
+            QUERY_ID,
+            CONTROLLER_SERVER_HOST,
+            temporaryFolder.toFile(),
+            QueryContext.empty()
+        )
+    ).thenReturn(worker);
+
+    // "worker.run()" exits when "workerRun" resolves.
+    Mockito.doAnswer(invocation -> {
+      workerRun.get();
+      return null;
+    }).when(worker).run();
+
+    // "worker.stop()" sets "workerRun" to a cancellation error.
+    Mockito.doAnswer(invocation -> {
+      workerRun.setException(new MSQException(CanceledFault.instance()));
+      return null;
+    }).when(worker).stop();
+
+    // "worker.controllerFailed()" sets "workerRun" to an error.
+    Mockito.doAnswer(invocation -> {
+      workerRun.setException(new ISE("Controller failed"));
+      return null;
+    }).when(worker).controllerFailed();
+
+    // "worker.awaitStop()" waits for "workerRun". It does not throw an exception, just like WorkerImpl.awaitStop.
+    Mockito.doAnswer(invocation -> {
+      try {
+        workerRun.get();
+      }
+      catch (Throwable e) {
+        // Suppress
+      }
+      return null;
+    }).when(worker).awaitStop();
+
+    // "worker.id()" returns WORKER_ID.
+    Mockito.when(worker.id()).thenReturn(WORKER_ID.toString());
+
+    // Start workerRunner, capture listener in "discoveryListener".
+    workerRunner.start();
+    Mockito.verify(discovery).registerListener(discoveryListener.capture());
+  }
+
+  /**
+   * Shuts down the worker executor, stops the runner, closes the mocks, and then fails the test if the executor
+   * has not terminated within one minute.
+   */
+  @AfterEach
+  public void tearDown() throws Exception
+  {
+    workerExec.shutdown();
+    workerRunner.stop();
+    mockCloser.close();
+
+    if (!workerExec.awaitTermination(1, TimeUnit.MINUTES)) {
+      throw new ISE("workerExec did not terminate within timeout");
+    }
+  }
+
+  @Test
+  public void test_getWorkersResponse_empty()
+  {
+    final GetWorkersResponse workersResponse = workerRunner.getWorkersResponse();
+    Assertions.assertEquals(new GetWorkersResponse(Collections.emptyList()), workersResponse);
+  }
+
+  @Test
+  public void test_getWorkerResource_notFound()
+  {
+    Assertions.assertNull(workerRunner.getWorkerResource("nonexistent"));
+  }
+
+  @Test
+  public void test_createAndCleanTempDirectory() throws IOException
+  {
+    workerRunner.stop();
+
+    // Create an empty directory "x".
+    FileUtils.mkdirp(new File(temporaryFolder.toFile(), "x"));
+    Assertions.assertArrayEquals(
+        new File[]{new File(temporaryFolder.toFile(), "x")},
+        temporaryFolder.toFile().listFiles()
+    );
+
+    // Run "createAndCleanTempDirectory", which will delete it.
+    workerRunner.createAndCleanTempDirectory();
+    Assertions.assertArrayEquals(new File[]{}, temporaryFolder.toFile().listFiles());
+  }
+
+  /**
+   * Starting a worker before its controller has been announced through discovery is rejected.
+   */
+  @Test
+  public void test_startWorker_controllerNotActive()
+  {
+    final DruidException e = Assertions.assertThrows(
+        DruidException.class,
+        () -> workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty())
+    );
+
+    MatcherAssert.assertThat(
+        e,
+        ThrowableMessageMatcher.hasMessage(CoreMatchers.containsString(
+            "Received startWorker request for unknown controller"))
+    );
+  }
+
+  @Test
+  public void test_stopWorker_nonexistent()
+  {
+    // Nothing happens when we do this. Just verifying an exception isn't thrown.
+    workerRunner.stopWorker("nonexistent");
+  }
+
+  @Test
+  public void test_startWorker()
+  {
+    // Activate controller.
+    discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE));
+
+    // Start the worker twice (startWorker is idempotent; nothing special happens the second time).
+    final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty());
+    final Worker workerFromStart2 = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty());
+    Assertions.assertSame(worker, workerFromStart);
+    Assertions.assertSame(worker, workerFromStart2);
+
+    // Worker should enter the GetWorkersResponse.
+    final GetWorkersResponse workersResponse = workerRunner.getWorkersResponse();
+    Assertions.assertEquals(1, workersResponse.getWorkers().size());
+    Assertions.assertEquals(QUERY_ID, workersResponse.getWorkers().get(0).getDartQueryId());
+    Assertions.assertEquals(CONTROLLER_SERVER_HOST, workersResponse.getWorkers().get(0).getControllerHost());
+    Assertions.assertEquals(WORKER_ID, workersResponse.getWorkers().get(0).getWorkerId());
+
+    // Worker should have a resource.
+    Assertions.assertNotNull(workerRunner.getWorkerResource(QUERY_ID));
+  }
+
+  @Test
+  @Timeout(value = 1, unit = TimeUnit.MINUTES)
+  public void test_startWorker_thenRemoveController() throws InterruptedException
+  {
+    // Activate controller.
+    discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE));
+
+    // Start the worker.
+    final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty());
+    Assertions.assertSame(worker, workerFromStart);
+    Assertions.assertEquals(1, workerRunner.getWorkersResponse().getWorkers().size());
+
+    // Deactivate controller.
+    discoveryListener.getValue().nodesRemoved(Collections.singletonList(CONTROLLER_DISCOVERY_NODE));
+
+    // Worker should go away.
+    workerRunner.awaitQuerySet(Set::isEmpty);
+    Assertions.assertEquals(0, workerRunner.getWorkersResponse().getWorkers().size());
+  }
+
+  @Test
+  @Timeout(value = 1, unit = TimeUnit.MINUTES)
+  public void test_startWorker_thenStopWorker() throws InterruptedException
+  {
+    // Activate controller.
+    discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE));
+
+    // Start the worker.
+    final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty());
+    Assertions.assertSame(worker, workerFromStart);
+    Assertions.assertEquals(1, workerRunner.getWorkersResponse().getWorkers().size());
+
+    // Stop that worker.
+    workerRunner.stopWorker(QUERY_ID);
+
+    // Worker should go away.
+    workerRunner.awaitQuerySet(Set::isEmpty);
+    Assertions.assertEquals(0, workerRunner.getWorkersResponse().getWorkers().size());
+  }
+
+  @Test
+  @Timeout(value = 1, unit = TimeUnit.MINUTES)
+  public void test_startWorker_thenStopRunner() throws InterruptedException
+  {
+    // Activate controller.
+    discoveryListener.getValue().nodesAdded(Collections.singletonList(CONTROLLER_DISCOVERY_NODE));
+
+    // Start the worker.
+    final Worker workerFromStart = workerRunner.startWorker(QUERY_ID, CONTROLLER_SERVER_HOST, QueryContext.empty());
+    Assertions.assertSame(worker, workerFromStart);
+    Assertions.assertEquals(1, workerRunner.getWorkersResponse().getWorkers().size());
+
+    // Stop runner.
+    workerRunner.stop();
+
+    // Worker should go away.
+    workerRunner.awaitQuerySet(Set::isEmpty);
+    Assertions.assertEquals(0, workerRunner.getWorkersResponse().getWorkers().size());
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/WorkerIdTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/WorkerIdTest.java
new file mode 100644
index 000000000000..e4f74a0250f6
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/WorkerIdTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.worker;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.segment.TestHelper;
+import org.apache.druid.server.DruidNode;
+import org.apache.druid.server.coordination.DruidServerMetadata;
+import org.apache.druid.server.coordination.ServerType;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+
+/**
+ * Tests construction, string form, getters, JSON round-tripping, and the equals contract of {@link WorkerId}.
+ */
+public class WorkerIdTest
+{
+  @Test
+  public void test_fromString()
+  {
+    final WorkerId expected = new WorkerId("https", "local-host:8100", "xyz");
+    final WorkerId parsed = WorkerId.fromString("https:local-host:8100:xyz");
+    Assertions.assertEquals(expected, parsed);
+  }
+
+  @Test
+  public void test_fromDruidNode()
+  {
+    final WorkerId expected = new WorkerId("https", "local-host:8100", "xyz");
+    final DruidNode node = new DruidNode("none", "local-host", false, 8200, 8100, true, true);
+    Assertions.assertEquals(expected, WorkerId.fromDruidNode(node, "xyz"));
+  }
+
+  @Test
+  public void test_fromDruidServerMetadata()
+  {
+    final WorkerId expected = new WorkerId("https", "local-host:8100", "xyz");
+    final DruidServerMetadata serverMetadata =
+        new DruidServerMetadata("none", "local-host:8200", "local-host:8100", 1, ServerType.HISTORICAL, "none", 0);
+    Assertions.assertEquals(expected, WorkerId.fromDruidServerMetadata(serverMetadata, "xyz"));
+  }
+
+  @Test
+  public void test_toString()
+  {
+    final String asString = new WorkerId("https", "local-host:8100", "xyz").toString();
+    Assertions.assertEquals("https:local-host:8100:xyz", asString);
+  }
+
+  /**
+   * Checks each getter and the URI built by {@link WorkerId#toUri()}.
+   */
+  @Test
+  public void test_getters()
+  {
+    final WorkerId id = new WorkerId("https", "local-host:8100", "xyz");
+    Assertions.assertEquals("https", id.getScheme());
+    Assertions.assertEquals("local-host:8100", id.getHostAndPort());
+    Assertions.assertEquals("xyz", id.getQueryId());
+    Assertions.assertEquals("https://local-host:8100/druid/dart-worker/workers/xyz", id.toUri().toString());
+  }
+
+  @Test
+  public void test_serde() throws IOException
+  {
+    final ObjectMapper mapper = TestHelper.JSON_MAPPER;
+    final WorkerId original = new WorkerId("https", "localhost:8100", "xyz");
+    final byte[] serialized = mapper.writeValueAsBytes(original);
+    final WorkerId roundTripped = mapper.readValue(serialized, WorkerId.class);
+    Assertions.assertEquals(original, roundTripped);
+  }
+
+  /**
+   * Runs EqualsVerifier with "fullString" declared non-null and "scheme", "hostAndPort", and "queryId" ignored.
+   */
+  @Test
+  public void test_equals()
+  {
+    EqualsVerifier.forClass(WorkerId.class)
+                  .usingGetClass()
+                  .withNonnullFields("fullString")
+                  .withIgnoredFields("scheme", "hostAndPort", "queryId")
+                  .verify();
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfoTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfoTest.java
new file mode 100644
index 000000000000..74cd8a28915a
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/DartWorkerInfoTest.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.worker.http;
+
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Tests the equals contract of {@link DartWorkerInfo}.
+ */
+public class DartWorkerInfoTest
+{
+  /**
+   * Runs EqualsVerifier against {@link DartWorkerInfo}, configured with {@code usingGetClass()}.
+   */
+  @Test
+  public void test_equals()
+  {
+    EqualsVerifier.forClass(DartWorkerInfo.class)
+                  .usingGetClass()
+                  .verify();
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponseTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponseTest.java
new file mode 100644
index 000000000000..f516077a5754
--- /dev/null
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/dart/worker/http/GetWorkersResponseTest.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.msq.dart.worker.http;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import nl.jqno.equalsverifier.EqualsVerifier;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.msq.dart.worker.WorkerId;
+import org.apache.druid.segment.TestHelper;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+
+/**
+ * Tests JSON round-tripping and the equals contract of {@link GetWorkersResponse}.
+ */
+public class GetWorkersResponseTest
+{
+  /**
+   * Serializes a response holding a single {@link DartWorkerInfo} with {@link TestHelper#JSON_MAPPER}, reads it
+   * back, and expects the result to equal the original.
+   */
+  @Test
+  public void test_serde() throws Exception
+  {
+    final DartWorkerInfo workerInfo = new DartWorkerInfo(
+        "xyz",
+        WorkerId.fromString("http:localhost:8100:xyz"),
+        "localhost:8101",
+        DateTimes.of("2000")
+    );
+    final GetWorkersResponse original = new GetWorkersResponse(Collections.singletonList(workerInfo));
+
+    final ObjectMapper mapper = TestHelper.JSON_MAPPER;
+    final byte[] serialized = mapper.writeValueAsBytes(original);
+    final GetWorkersResponse roundTripped = mapper.readValue(serialized, GetWorkersResponse.class);
+
+    Assertions.assertEquals(original, roundTripped);
+  }
+
+  /**
+   * Runs EqualsVerifier against {@link GetWorkersResponse}, configured with {@code usingGetClass()}.
+   */
+  @Test
+  public void test_equals()
+  {
+    EqualsVerifier.forClass(GetWorkersResponse.class)
+                  .usingGetClass()
+                  .verify();
+  }
+}
diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
index e1ce49d82923..89018596be2c 100644
--- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
+++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestBase.java
@@ -320,6 +320,7 @@ public class MSQTestBase extends BaseCalciteQueryTest
   protected File localFileStorageDir;
   protected LocalFileStorageConnector localFileStorageConnector;
   private static final Logger log = new Logger(MSQTestBase.class);
+  protected Injector injector;
   protected ObjectMapper objectMapper;
   protected MSQTestOverlordServiceClient indexingServiceClient;
   protected MSQTestTaskActionClient testTaskActionClient;
@@ -530,7 +531,7 @@ public String
getFormatString() binder -> binder.bind(Bouncer.class).toInstance(new Bouncer(1)) ); // adding node role injection to the modules, since CliPeon would also do that through run method - Injector injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) + injector = new CoreInjectorBuilder(new StartupInjectorBuilder().build(), ImmutableSet.of(NodeRole.PEON)) .addAll(modules) .build(); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java index 970d873c96c8..4dadeae5bc10 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestControllerContext.java @@ -56,7 +56,6 @@ import org.apache.druid.msq.exec.WorkerStorageParameters; import org.apache.druid.msq.indexing.IndexerControllerContext; import org.apache.druid.msq.indexing.IndexerTableInputSpecSlicer; -import org.apache.druid.msq.indexing.MSQControllerTask; import org.apache.druid.msq.indexing.MSQSpec; import org.apache.druid.msq.indexing.MSQWorkerTask; import org.apache.druid.msq.indexing.MSQWorkerTaskLauncher; @@ -108,8 +107,8 @@ public class MSQTestControllerContext implements ControllerContext private Controller controller; private final WorkerMemoryParameters workerMemoryParameters; + private final TaskLockType taskLockType; private final QueryContext queryContext; - private final MSQControllerTask controllerTask; public MSQTestControllerContext( ObjectMapper mapper, @@ -117,7 +116,8 @@ public MSQTestControllerContext( TaskActionClient taskActionClient, WorkerMemoryParameters workerMemoryParameters, List loadedSegments, - MSQControllerTask controllerTask + TaskLockType taskLockType, + QueryContext queryContext ) { this.mapper = mapper; @@ -137,8 +137,8 @@ 
public MSQTestControllerContext( .collect(Collectors.toList()) ); this.workerMemoryParameters = workerMemoryParameters; - this.controllerTask = controllerTask; - this.queryContext = controllerTask.getQuerySpec().getQuery().context(); + this.taskLockType = taskLockType; + this.queryContext = queryContext; } OverlordClient overlordClient = new NoopOverlordClient() @@ -329,7 +329,7 @@ public TaskActionClient taskActionClient() @Override public TaskLockType taskLockType() { - return controllerTask.getTaskLockType(); + return taskLockType; } @Override diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java index 6a7db8aa5b63..b35c074fa060 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestOverlordServiceClient.java @@ -103,7 +103,8 @@ public ListenableFuture runTask(String taskId, Object taskObject) taskActionClient, workerMemoryParameters, loadedSegmentMetadata, - cTask + cTask.getTaskLockType(), + cTask.getQuerySpec().getQuery().context() ); inMemoryControllerTask.put(cTask.getId(), cTask); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index ffd7c67ca2d6..4c7ccd72efd0 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -35,10 +35,12 @@ import java.io.InputStream; import java.util.Arrays; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; public class MSQTestWorkerClient implements 
WorkerClient { private final Map inMemoryWorkers; + private final AtomicBoolean closed = new AtomicBoolean(); public MSQTestWorkerClient(Map inMemoryWorkers) { @@ -141,6 +143,8 @@ public ListenableFuture fetchChannelData( @Override public void close() { - inMemoryWorkers.forEach((k, v) -> v.stop()); + if (closed.compareAndSet(false, true)) { + inMemoryWorkers.forEach((k, v) -> v.stop()); + } } } diff --git a/processing/src/main/java/org/apache/druid/common/guava/FutureBox.java b/processing/src/main/java/org/apache/druid/common/guava/FutureBox.java new file mode 100644 index 000000000000..3e92706aa028 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/common/guava/FutureBox.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.common.guava; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.java.util.common.concurrent.Execs; + +import java.io.Closeable; +import java.util.Set; + +/** + * Box for tracking pending futures. Allows cancellation of all pending futures. + */ +public class FutureBox implements Closeable +{ + /** + * Currently-outstanding futures. 
These are tracked so they can be canceled in {@link #close()}. + */ + private final Set> pendingFutures = Sets.newConcurrentHashSet(); + + private volatile boolean stopped; + + /** + * Returns the number of currently-pending futures. + */ + public int pendingCount() + { + return pendingFutures.size(); + } + + /** + * Adds a future to the box. + * If {@link #close()} had previously been called, the future is immediately canceled. + */ + public ListenableFuture register(final ListenableFuture future) + { + pendingFutures.add(future); + future.addListener(() -> pendingFutures.remove(future), Execs.directExecutor()); + + // If close() was called while we were registering this future, cancel it prior to returning it. + if (stopped) { + future.cancel(false); + } + + return future; + } + + /** + * Closes the box, canceling all currently-pending futures. + */ + @Override + public void close() + { + stopped = true; + for (ListenableFuture future : pendingFutures) { + future.cancel(false); // Ignore return value + } + } +} diff --git a/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java new file mode 100644 index 000000000000..6d27abb42739 --- /dev/null +++ b/processing/src/main/java/org/apache/druid/io/LimitedOutputStream.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.io; + +import org.apache.druid.error.DruidException; +import org.apache.druid.java.util.common.IOE; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.function.Function; + +/** + * An {@link OutputStream} that limits how many bytes can be written. Throws {@link IOException} if the limit + * is exceeded. + */ +public class LimitedOutputStream extends OutputStream +{ + private final OutputStream out; + private final long limit; + private final Function exceptionMessageFn; + long written; + + /** + * Create a bytes-limited output stream. + * + * @param out output stream to wrap + * @param limit bytes limit + * @param exceptionMessageFn function for generating an exception message for an {@link IOException}, given the limit. 
+ */ + public LimitedOutputStream(OutputStream out, long limit, Function exceptionMessageFn) + { + this.out = out; + this.limit = limit; + this.exceptionMessageFn = exceptionMessageFn; + + if (limit < 0) { + throw DruidException.defensive("Limit[%s] must be greater than or equal to zero", limit); + } + } + + @Override + public void write(int b) throws IOException + { + plus(1); + out.write(b); + } + + @Override + public void write(byte[] b) throws IOException + { + plus(b.length); + out.write(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException + { + plus(len); + out.write(b, off, len); + } + + @Override + public void flush() throws IOException + { + out.flush(); + } + + @Override + public void close() throws IOException + { + out.close(); + } + + private void plus(final int n) throws IOException + { + written += n; + if (written > limit) { + throw new IOE(exceptionMessageFn.apply(limit)); + } + } +} diff --git a/processing/src/test/java/org/apache/druid/common/guava/FutureBoxTest.java b/processing/src/test/java/org/apache/druid/common/guava/FutureBoxTest.java new file mode 100644 index 000000000000..7428f94fa71a --- /dev/null +++ b/processing/src/test/java/org/apache/druid/common/guava/FutureBoxTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.common.guava; + +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import org.junit.Test; +import org.junit.jupiter.api.Assertions; + +import java.util.concurrent.ExecutionException; + +public class FutureBoxTest +{ + @Test + public void test_immediateFutures() throws Exception + { + try (final FutureBox box = new FutureBox()) { + Assertions.assertEquals("a", box.register(Futures.immediateFuture("a")).get()); + Assertions.assertThrows( + ExecutionException.class, + () -> box.register(Futures.immediateFailedFuture(new RuntimeException())).get() + ); + Assertions.assertTrue(box.register(Futures.immediateCancelledFuture()).isCancelled()); + Assertions.assertEquals(0, box.pendingCount()); + } + } + + @Test + public void test_register_thenStop() + { + final FutureBox box = new FutureBox(); + final SettableFuture settableFuture = SettableFuture.create(); + + final ListenableFuture retVal = box.register(settableFuture); + Assertions.assertSame(retVal, settableFuture); + Assertions.assertEquals(1, box.pendingCount()); + + box.close(); + Assertions.assertEquals(0, box.pendingCount()); + + Assertions.assertTrue(settableFuture.isCancelled()); + } + + @Test + public void test_stop_thenRegister() + { + final FutureBox box = new FutureBox(); + final SettableFuture settableFuture = SettableFuture.create(); + + box.close(); + final ListenableFuture retVal = box.register(settableFuture); + + Assertions.assertSame(retVal, settableFuture); + Assertions.assertEquals(0, box.pendingCount()); + Assertions.assertTrue(settableFuture.isCancelled()); + } +} diff --git a/processing/src/test/java/org/apache/druid/io/LimitedOutputStreamTest.java b/processing/src/test/java/org/apache/druid/io/LimitedOutputStreamTest.java new file mode 100644 
index 000000000000..a11b63149710 --- /dev/null +++ b/processing/src/test/java/org/apache/druid/io/LimitedOutputStreamTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.io; + +import org.apache.druid.java.util.common.StringUtils; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.junit.Assert; +import org.junit.Test; +import org.junit.internal.matchers.ThrowableMessageMatcher; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +public class LimitedOutputStreamTest +{ + @Test + public void test_limitZero() throws IOException + { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final OutputStream stream = + new LimitedOutputStream(baos, 0, LimitedOutputStreamTest::makeErrorMessage)) { + final IOException e = Assert.assertThrows( + IOException.class, + () -> stream.write('b') + ); + + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("Limit[0] exceeded"))); + } + } + + @Test + public void test_limitThree() throws IOException + { + try (final ByteArrayOutputStream baos = new ByteArrayOutputStream(); + final OutputStream 
stream = + new LimitedOutputStream(baos, 3, LimitedOutputStreamTest::makeErrorMessage)) { + stream.write('a'); + stream.write(new byte[]{'b'}); + stream.write(new byte[]{'c'}, 0, 1); + final IOException e = Assert.assertThrows( + IOException.class, + () -> stream.write('d') + ); + + MatcherAssert.assertThat(e, ThrowableMessageMatcher.hasMessage(CoreMatchers.equalTo("Limit[3] exceeded"))); + } + } + + private static String makeErrorMessage(final long limit) + { + return StringUtils.format("Limit[%d] exceeded", limit); + } +} diff --git a/server/src/main/java/org/apache/druid/client/BrokerServerView.java b/server/src/main/java/org/apache/druid/client/BrokerServerView.java index 2cb2bec03b59..f2eb62db0208 100644 --- a/server/src/main/java/org/apache/druid/client/BrokerServerView.java +++ b/server/src/main/java/org/apache/druid/client/BrokerServerView.java @@ -44,6 +44,7 @@ import org.apache.druid.timeline.VersionedIntervalTimeline; import org.apache.druid.timeline.partition.PartitionChunk; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -398,6 +399,19 @@ private void runTimelineCallbacks(final Function getDruidServerMetadatas() + { + // Override default implementation for better performance. 
+ final List retVal = new ArrayList<>(clients.size()); + + for (final QueryableDruidServer server : clients.values()) { + retVal.add(server.getServer().getMetadata()); + } + + return retVal; + } + @Override public List getDruidServers() { diff --git a/server/src/main/java/org/apache/druid/client/TimelineServerView.java b/server/src/main/java/org/apache/druid/client/TimelineServerView.java index 9a2b7b767755..9c6ee608e1f4 100644 --- a/server/src/main/java/org/apache/druid/client/TimelineServerView.java +++ b/server/src/main/java/org/apache/druid/client/TimelineServerView.java @@ -27,6 +27,7 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.TimelineLookup; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.concurrent.Executor; @@ -45,10 +46,23 @@ public interface TimelineServerView extends ServerView * * @throws IllegalStateException if 'analysis' does not represent a scan-based datasource of a single table */ - Optional> getTimeline(DataSourceAnalysis analysis); + > Optional getTimeline(DataSourceAnalysis analysis); /** - * Returns a list of {@link ImmutableDruidServer} + * Returns a snapshot of the current set of server metadata. + */ + default List getDruidServerMetadatas() + { + final List druidServers = getDruidServers(); + final List metadatas = new ArrayList<>(druidServers.size()); + for (final ImmutableDruidServer druidServer : druidServers) { + metadatas.add(druidServer.getMetadata()); + } + return metadatas; + } + + /** + * Returns a snapshot of the current servers, their metadata, and their inventory. 
*/ List getDruidServers(); diff --git a/server/src/main/java/org/apache/druid/discovery/DataServerClient.java b/server/src/main/java/org/apache/druid/discovery/DataServerClient.java index ce7ac325b62b..ce3d62ca91b5 100644 --- a/server/src/main/java/org/apache/druid/discovery/DataServerClient.java +++ b/server/src/main/java/org/apache/druid/discovery/DataServerClient.java @@ -35,7 +35,7 @@ import org.apache.druid.java.util.http.client.response.StatusResponseHolder; import org.apache.druid.query.Query; import org.apache.druid.query.context.ResponseContext; -import org.apache.druid.rpc.FixedSetServiceLocator; +import org.apache.druid.rpc.FixedServiceLocator; import org.apache.druid.rpc.RequestBuilder; import org.apache.druid.rpc.ServiceClient; import org.apache.druid.rpc.ServiceClientFactory; @@ -71,7 +71,7 @@ public DataServerClient( { this.serviceClient = serviceClientFactory.makeClient( serviceLocation.getHost(), - FixedSetServiceLocator.forServiceLocation(serviceLocation), + new FixedServiceLocator(serviceLocation), StandardRetryPolicy.noRetries() ); this.serviceLocation = serviceLocation; diff --git a/server/src/main/java/org/apache/druid/messages/MessageBatch.java b/server/src/main/java/org/apache/druid/messages/MessageBatch.java new file mode 100644 index 000000000000..51209ff6d243 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/MessageBatch.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; +import org.apache.druid.messages.client.MessageRelay; +import org.apache.druid.messages.server.MessageRelayResource; +import org.apache.druid.messages.server.Outbox; + +import java.util.List; +import java.util.Objects; + +/** + * A batch of messages collected by {@link MessageRelay} from a remote {@link Outbox} through + * {@link MessageRelayResource#httpGetMessagesFromOutbox}. + */ +public class MessageBatch +{ + private final List messages; + private final long epoch; + private final long startWatermark; + + @JsonCreator + public MessageBatch( + @JsonProperty("messages") final List messages, + @JsonProperty("epoch") final long epoch, + @JsonProperty("watermark") final long startWatermark + ) + { + this.messages = messages; + this.epoch = epoch; + this.startWatermark = startWatermark; + } + + /** + * The messages. + */ + @JsonProperty + public List getMessages() + { + return messages; + } + + /** + * Epoch, which is associated with a specific instance of {@link Outbox}. + */ + @JsonProperty + @JsonInclude(JsonInclude.Include.NON_DEFAULT) + public long getEpoch() + { + return epoch; + } + + /** + * Watermark, an incrementing message ID that enables clients and servers to stay in sync, and enables + * acknowledging of messages. 
+ */ + @JsonProperty("watermark") + @JsonInclude(JsonInclude.Include.NON_DEFAULT) + public long getStartWatermark() + { + return startWatermark; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + MessageBatch that = (MessageBatch) o; + return epoch == that.epoch && startWatermark == that.startWatermark && Objects.equals(messages, that.messages); + } + + @Override + public int hashCode() + { + return Objects.hash(messages, epoch, startWatermark); + } + + @Override + public String toString() + { + return "MessageBatch{" + + "messages=" + messages + + ", epoch=" + epoch + + ", startWatermark=" + startWatermark + + '}'; + } +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageListener.java b/server/src/main/java/org/apache/druid/messages/client/MessageListener.java new file mode 100644 index 000000000000..6711c418f812 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageListener.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages.client; + +import org.apache.druid.server.DruidNode; + +/** + * Listener for messages received by clients. + */ +public interface MessageListener +{ + /** + * Called when a server is added. + * + * @param node server node + */ + void serverAdded(DruidNode node); + + /** + * Called when a message is received. Should not throw exceptions. If this method does throw an exception, + * the exception is logged and the message is acknowledged anyway. + * + * @param message the message that was received + */ + void messageReceived(MessageType message); + + /** + * Called when a server is removed. + * + * @param node server node + */ + void serverRemoved(DruidNode node); +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelay.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelay.java new file mode 100644 index 000000000000..deda87c7d23d --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelay.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages.client; + +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.java.util.common.ISE; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.server.MessageRelayResource; +import org.apache.druid.rpc.ServiceClosedException; +import org.apache.druid.server.DruidNode; + +import java.io.Closeable; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Relays run on clients, and receive messages from a server. + * Uses {@link MessageRelayClient} to communicate with the {@link MessageRelayResource} on a server. + * Messages flow upstream, from server to client, and are handed to the provided {@link MessageListener}. + */ +public class MessageRelay implements Closeable +{ + private static final Logger log = new Logger(MessageRelay.class); + + /** + * Value to provide for epoch on the initial call to {@link MessageRelayClient#getMessages(String, long, long)}. + */ + public static final long INIT = -1; + + private final String selfHost; + private final DruidNode serverNode; + private final MessageRelayClient client; + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Collector collector; + + public MessageRelay( + final String selfHost, + final DruidNode serverNode, + final MessageRelayClient client, + final MessageListener listener + ) + { + this.selfHost = selfHost; + this.serverNode = serverNode; + this.client = client; + this.collector = new Collector(listener); + } + + /** + * Start the {@link Collector}. + */ + public void start() + { + collector.start(); + } + + /** + * Stop the {@link Collector}.
+ */ + @Override + public void close() + { + if (closed.compareAndSet(false, true)) { + collector.stop(); + } + } + + /** + * Retrieves messages that are being sent to this client and hands them to {@link #listener}. + */ + private class Collector + { + private final MessageListener listener; + private final AtomicLong epoch = new AtomicLong(INIT); + private final AtomicLong watermark = new AtomicLong(INIT); + private final AtomicReference> currentCall = new AtomicReference<>(); + + public Collector(final MessageListener listener) + { + this.listener = listener; + } + + private void start() + { + if (!watermark.compareAndSet(INIT, 0)) { + throw new ISE("Already started"); + } + + listener.serverAdded(serverNode); + issueNextGetMessagesCall(); + } + + private void issueNextGetMessagesCall() + { + if (closed.get()) { + return; + } + + final long theEpoch = epoch.get(); + final long theWatermark = watermark.get(); + + log.debug( + "Getting messages from server[%s] for client[%s] (current state: epoch[%s] watermark[%s]).", + serverNode.getHostAndPortToUse(), + selfHost, + theEpoch, + theWatermark + ); + + final ListenableFuture> future = client.getMessages(selfHost, theEpoch, theWatermark); + + if (!currentCall.compareAndSet(null, future)) { + log.error( + "Fatal error: too many outgoing calls to server[%s] for client[%s] " + + "(current state: epoch[%s] watermark[%s]). 
Closing collector.", + serverNode.getHostAndPortToUse(), + selfHost, + theEpoch, + theWatermark + ); + + close(); + return; + } + + Futures.addCallback( + future, + new FutureCallback>() + { + @Override + public void onSuccess(final MessageBatch result) + { + log.debug("Received message batch: %s", result); + currentCall.compareAndSet(future, null); + final long endWatermark = result.getStartWatermark() + result.getMessages().size(); + if (theEpoch == INIT) { + epoch.set(result.getEpoch()); + watermark.set(endWatermark); + } else if (epoch.get() != result.getEpoch() + || !watermark.compareAndSet(result.getStartWatermark(), endWatermark)) { + // We don't expect to see this unless there is somehow another collector running with the same + // clientHost. If the unexpected happens, log it and close the collector. It will stay, doing + // nothing, in the MessageCollectors map until it is removed by the discovery listener. + log.error( + "Incorrect epoch + watermark from server[%s] for client[%s] " + + "(expected[%s:%s] but got[%s:%s]). " + + "Closing collector.", + serverNode.getHostAndPortToUse(), + selfHost, + theEpoch, + theWatermark, + result.getEpoch(), + result.getStartWatermark() + ); + + close(); + return; + } + + for (final MessageType message : result.getMessages()) { + try { + listener.messageReceived(message); + } + catch (Throwable e) { + log.warn( + e, + "Failed to handle message[%s] from server[%s] for client[%s].", + message, + selfHost, + serverNode.getHostAndPortToUse() + ); + } + } + + issueNextGetMessagesCall(); + } + + @Override + public void onFailure(final Throwable e) + { + currentCall.compareAndSet(future, null); + if (!(e instanceof CancellationException) && !(e instanceof ServiceClosedException)) { + // We don't expect to see any other errors, since we use an unlimited retry policy for clients. If the + // unexpected happens, log it and close the collector. 
It will stay, doing nothing, in the + // MessageCollectors map until it is removed by the discovery listener. + log.error( + e, + "Fatal error contacting server[%s] for client[%s] " + + "(current state: epoch[%s] watermark[%s]). " + + "Closing collector.", + serverNode.getHostAndPortToUse(), + selfHost, + theEpoch, + theWatermark + ); + } + + close(); + } + }, + Execs.directExecutor() + ); + } + + public void stop() + { + final ListenableFuture future = currentCall.getAndSet(null); + if (future != null) { + future.cancel(true); + } + + try { + listener.serverRemoved(serverNode); + } + catch (Throwable e) { + log.warn(e, "Failed to close server[%s]", serverNode.getHostAndPortToUse()); + } + } + } +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelayClient.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClient.java new file mode 100644 index 000000000000..fad228f7b5f0 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClient.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages.client; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.server.MessageRelayResource; + +/** + * Client for {@link MessageRelayResource}. + */ +public interface MessageRelayClient +{ + /** + * Get the next batch of messages from an outbox. + * + * @param clientHost which outbox to retrieve messages from. Each clientHost has its own outbox. + * @param epoch outbox epoch, or {@link MessageRelay#INIT} if this is the first call from the collector. + * @param startWatermark outbox message watermark to retrieve from. + * + * @return future that resolves to the next batch of messages + * + * @see MessageRelayResource#httpGetMessagesFromOutbox http endpoint this method calls + */ + ListenableFuture> getMessages(String clientHost, long epoch, long startWatermark); +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelayClientImpl.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClientImpl.java new file mode 100644 index 000000000000..140bd45e1af4 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelayClientImpl.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.client; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.java.util.common.StringUtils; +import org.apache.druid.java.util.common.jackson.JacksonUtils; +import org.apache.druid.java.util.http.client.response.BytesFullResponseHandler; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.rpc.RequestBuilder; +import org.apache.druid.rpc.ServiceClient; +import org.eclipse.jetty.http.HttpStatus; +import org.jboss.netty.handler.codec.http.HttpMethod; + +import java.util.Collections; + +/** + * Production implementation of {@link MessageRelayClient}. + */ +public class MessageRelayClientImpl implements MessageRelayClient +{ + private final ServiceClient serviceClient; + private final ObjectMapper smileMapper; + private final JavaType inMessageBatchType; + + public MessageRelayClientImpl( + final ServiceClient serviceClient, + final ObjectMapper smileMapper, + final Class inMessageClass + ) + { + this.serviceClient = serviceClient; + this.smileMapper = smileMapper; + this.inMessageBatchType = smileMapper.getTypeFactory().constructParametricType(MessageBatch.class, inMessageClass); + } + + @Override + public ListenableFuture> getMessages( + final String clientHost, + final long epoch, + final long startWatermark + ) + { + final String path = StringUtils.format( + "/outbox/%s/messages?epoch=%d&watermark=%d", + StringUtils.urlEncode(clientHost), + epoch, + startWatermark + ); + + return FutureUtils.transform( + serviceClient.asyncRequest( + new RequestBuilder(HttpMethod.GET, path), + new BytesFullResponseHandler() + ), + holder -> { + if (holder.getResponse().getStatus().getCode() == HttpStatus.NO_CONTENT_204) { + return 
new MessageBatch<>(Collections.emptyList(), epoch, startWatermark); + } else { + return JacksonUtils.readValue(smileMapper, holder.getContent(), inMessageBatchType); + } + } + ); + } +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelayFactory.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelayFactory.java new file mode 100644 index 000000000000..b647b9e4b6a2 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelayFactory.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.client; + +import org.apache.druid.server.DruidNode; + +/** + * Factory for creating new message relays. Used by {@link MessageRelays}. 
+ */ +public interface MessageRelayFactory +{ + MessageRelay newRelay(DruidNode druidNode); +} diff --git a/server/src/main/java/org/apache/druid/messages/client/MessageRelays.java b/server/src/main/java/org/apache/druid/messages/client/MessageRelays.java new file mode 100644 index 000000000000..e7af8fc51b55 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/client/MessageRelays.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages.client; + +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.guice.ManageLifecycle; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.lifecycle.LifecycleStart; +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.server.DruidNode; +import org.apache.druid.utils.CloseableUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * Manages a fleet of {@link MessageRelay}, one for each server discovered by a {@link DruidNodeDiscoveryProvider}. + */ +@ManageLifecycle +public class MessageRelays +{ + private static final Logger log = new Logger(MessageRelays.class); + + @GuardedBy("serverRelays") + private final Map> serverRelays = new HashMap<>(); + private final Supplier discoverySupplier; + private final MessageRelayFactory messageRelayFactory; + private final MessageRelaysListener listener; + + private volatile DruidNodeDiscovery discovery; + + public MessageRelays( + final Supplier discoverySupplier, + final MessageRelayFactory messageRelayFactory + ) + { + this.discoverySupplier = discoverySupplier; + this.messageRelayFactory = messageRelayFactory; + this.listener = new MessageRelaysListener(); + } + + @LifecycleStart + public void start() + { + discovery = discoverySupplier.get(); + discovery.registerListener(listener); + } + + @LifecycleStop + public void stop() + { + if (discovery != null) { + discovery.removeListener(listener); + discovery = null; + } + + synchronized (serverRelays) { + try { + 
CloseableUtils.closeAll(serverRelays.values()); + } + catch (IOException e) { + throw new RuntimeException(e); + } + finally { + serverRelays.clear(); + } + } + } + + /** + * Discovery listener. Creates and tears down individual host relays. + */ + class MessageRelaysListener implements DruidNodeDiscovery.Listener + { + @Override + public void nodesAdded(final Collection nodes) + { + synchronized (serverRelays) { + for (final DiscoveryDruidNode node : nodes) { + final DruidNode druidNode = node.getDruidNode(); + + serverRelays.computeIfAbsent(druidNode.getHostAndPortToUse(), ignored -> { + final MessageRelay relay = messageRelayFactory.newRelay(druidNode); + relay.start(); + return relay; + }); + } + } + } + + @Override + public void nodesRemoved(final Collection nodes) + { + final List>> removed = new ArrayList<>(); + + synchronized (serverRelays) { + for (final DiscoveryDruidNode node : nodes) { + final DruidNode druidNode = node.getDruidNode(); + final String druidHost = druidNode.getHostAndPortToUse(); + final MessageRelay relay = serverRelays.remove(druidHost); + if (relay != null) { + removed.add(Pair.of(druidHost, relay)); + } + } + } + + for (final Pair> relay : removed) { + try { + relay.rhs.close(); + } + catch (Throwable e) { + log.noStackTrace().warn(e, "Could not close relay for server[%s]. Dropping.", relay.lhs); + } + } + } + } +} diff --git a/server/src/main/java/org/apache/druid/messages/package-info.java b/server/src/main/java/org/apache/druid/messages/package-info.java new file mode 100644 index 000000000000..9eb36d1e181c --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/package-info.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * Message relays provide a mechanism to send messages from server to client using long polling. The messages are + * sent in order, with acknowledgements from client to server when a message has been successfully delivered. + * + * This is useful when there is some need for some "downstream" servers to send low-latency messages to some + * "upstream" server, but where establishing connections from downstream servers to upstream servers would not be + * desirable. This is typically done when upstream servers want to keep state in-memory that is updated incrementally + * by downstream servers, and where there may be lots of instances of downstream servers. + * + * This structure has two main benefits. First, it prevents upstream servers from being overwhelmed by connections + * from downstream servers. Second, it allows upstream servers to drive the updates of their own state, and better + * handle events like restarts and leader changes. + * + * On the downstream (server) side, messages are placed into an {@link org.apache.druid.messages.server.Outbox} + * and served by a {@link org.apache.druid.messages.server.MessageRelayResource}. + * + * On the upstream (client) side, messages are retrieved by {@link org.apache.druid.messages.client.MessageRelays} + * using {@link org.apache.druid.messages.client.MessageRelayClient}. 
+ * + * This is currently used by Dart (multi-stage-query engine running on Brokers and Historicals) to implement + * worker-to-controller messages. In the future it may also be used to implement + * {@link org.apache.druid.server.coordination.ChangeRequestHttpSyncer}. + */ + +package org.apache.druid.messages; diff --git a/server/src/main/java/org/apache/druid/messages/server/MessageRelayMonitor.java b/server/src/main/java/org/apache/druid/messages/server/MessageRelayMonitor.java new file mode 100644 index 000000000000..1126f273ccaa --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/server/MessageRelayMonitor.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.DruidNodeDiscoveryProvider; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.java.util.common.lifecycle.LifecycleStart; + +import java.util.Collection; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Code that runs on message servers, to monitor their clients. 
When a client vanishes, its outbox is reset using + * {@link Outbox#resetOutbox(String)}. + */ +public class MessageRelayMonitor +{ + private final DruidNodeDiscoveryProvider discoveryProvider; + private final Outbox outbox; + private final NodeRole clientRole; + + public MessageRelayMonitor( + final DruidNodeDiscoveryProvider discoveryProvider, + final Outbox outbox, + final NodeRole clientRole + ) + { + this.discoveryProvider = discoveryProvider; + this.outbox = outbox; + this.clientRole = clientRole; + } + + @LifecycleStart + public void start() + { + discoveryProvider.getForNodeRole(clientRole).registerListener(new ClientListener()); + } + + /** + * Listener that cancels work associated with clients that have gone away. + */ + private class ClientListener implements DruidNodeDiscovery.Listener + { + @Override + public void nodesAdded(Collection nodes) + { + // Nothing to do. Although, perhaps it would make sense to *set up* an outbox here. (Currently, outboxes are + // created on-demand as they receive messages.) + } + + @Override + public void nodesRemoved(Collection nodes) + { + final Set hostsRemoved = + nodes.stream().map(node -> node.getDruidNode().getHostAndPortToUse()).collect(Collectors.toSet()); + + for (final String clientHost : hostsRemoved) { + outbox.resetOutbox(clientHost); + } + } + } +} diff --git a/server/src/main/java/org/apache/druid/messages/server/MessageRelayResource.java b/server/src/main/java/org/apache/druid/messages/server/MessageRelayResource.java new file mode 100644 index 000000000000..f8e771d378c7 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/server/MessageRelayResource.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.java.util.common.concurrent.Execs; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.client.MessageListener; +import org.apache.druid.messages.client.MessageRelayClient; + +import javax.servlet.AsyncContext; +import javax.servlet.AsyncEvent; +import javax.servlet.AsyncListener; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.PathParam; +import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; +import java.io.IOException; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Server-side resource for message relaying. Wraps an {@link Outbox} and {@link MessageListener}. + * The client for this resource is {@link MessageRelayClient}. 
+ */ +public class MessageRelayResource +{ + private static final Logger log = new Logger(MessageRelayResource.class); + private static final long GET_MESSAGES_TIMEOUT = 30_000L; + + /** + * Outbox for messages sent from this server. + */ + private final Outbox outbox; + + /** + * Message relay protocol uses Smile. + */ + private final ObjectMapper smileMapper; + + /** + * Type of {@link MessageBatch} of {@link MessageType}. + */ + private final JavaType batchType; + + public MessageRelayResource( + final Outbox outbox, + final ObjectMapper smileMapper, + final Class messageClass + ) + { + this.outbox = outbox; + this.smileMapper = smileMapper; + this.batchType = smileMapper.getTypeFactory().constructParametricType(MessageBatch.class, messageClass); + } + + /** + * Retrieve messages from the outbox for a particular client, as a {@link MessageBatch} in Smile format. + * The messages are retrieved from {@link Outbox#getMessages(String, long, long)}. + * + * This is a long-polling async method, using {@link AsyncContext} to wait up to {@link #GET_MESSAGES_TIMEOUT} for + * messages to appear in the outbox. 
+ * + * @return HTTP 200 with Smile response with messages on success; HTTP 204 (No Content) if no messages were put in + * the outbox before the timeout {@link #GET_MESSAGES_TIMEOUT} elapsed + * + * @see Outbox#getMessages(String, long, long) for more details on the API + */ + @GET + @Path("/outbox/{clientHost}/messages") + public Void httpGetMessagesFromOutbox( + @PathParam("clientHost") final String clientHost, + @QueryParam("epoch") final Long epoch, + @QueryParam("watermark") final Long watermark, + @Context final HttpServletRequest req + ) throws IOException + { + if (epoch == null || watermark == null || clientHost == null || clientHost.isEmpty()) { + AsyncContext asyncContext = req.startAsync(); + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.sendError(HttpServletResponse.SC_BAD_REQUEST); + asyncContext.complete(); + return null; + } + + final AtomicBoolean didRespond = new AtomicBoolean(); + final ListenableFuture> batchFuture = outbox.getMessages(clientHost, epoch, watermark); + final AsyncContext asyncContext = req.startAsync(); + asyncContext.setTimeout(GET_MESSAGES_TIMEOUT); + asyncContext.addListener( + new AsyncListener() + { + @Override + public void onComplete(AsyncEvent event) + { + } + + @Override + public void onTimeout(AsyncEvent event) + { + if (didRespond.compareAndSet(false, true)) { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.setStatus(HttpServletResponse.SC_NO_CONTENT); + event.getAsyncContext().complete(); + batchFuture.cancel(true); + } + } + + @Override + public void onError(AsyncEvent event) + { + } + + @Override + public void onStartAsync(AsyncEvent event) + { + } + } + ); + + // Save these items, since "req" becomes inaccessible in future exception handlers. 
+ final String remoteAddr = req.getRemoteAddr(); + final String requestURI = req.getRequestURI(); + + Futures.addCallback( + batchFuture, + new FutureCallback>() + { + @Override + public void onSuccess(MessageBatch result) + { + if (didRespond.compareAndSet(false, true)) { + log.debug("Sending message batch: %s", result); + try { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.setStatus(HttpServletResponse.SC_OK); + response.setContentType(SmileMediaTypes.APPLICATION_JACKSON_SMILE); + smileMapper.writerFor(batchType) + .writeValue(asyncContext.getResponse().getOutputStream(), result); + response.getOutputStream().close(); + asyncContext.complete(); + } + catch (Exception e) { + log.noStackTrace().warn(e, "Could not respond to request from[%s] to[%s]", remoteAddr, requestURI); + } + } + } + + @Override + public void onFailure(Throwable e) + { + if (didRespond.compareAndSet(false, true)) { + try { + HttpServletResponse response = (HttpServletResponse) asyncContext.getResponse(); + response.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + asyncContext.complete(); + } + catch (Exception e2) { + e.addSuppressed(e2); + } + + log.noStackTrace().warn(e, "Request failed from[%s] to[%s]", remoteAddr, requestURI); + } + } + }, + Execs.directExecutor() + ); + + return null; + } +} diff --git a/server/src/main/java/org/apache/druid/messages/server/Outbox.java b/server/src/main/java/org/apache/druid/messages/server/Outbox.java new file mode 100644 index 000000000000..4fcf130f0a9f --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/server/Outbox.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.client.MessageRelay; + +/** + * An outbox for messages sent from servers to clients. Messages are retrieved in the order they are sent. + * + * @see org.apache.druid.messages package-level javadoc for description of the message relay system + */ +public interface Outbox +{ + /** + * Send a message to a client, through an outbox. + * + * @param clientHost which outbox to send messages through. Each clientHost has its own outbox. + * @param message message to send + * + * @return future that resolves successfully when the client has acknowledged the message + */ + ListenableFuture sendMessage(String clientHost, MessageType message); + + /** + * Get the next batch of messages for an client, from an outbox. Messages are retrieved in the order they were sent. + * + * The provided epoch must either be {@link MessageRelay#INIT}, or must match the epoch of the outbox as indicated by + * {@link MessageBatch#getEpoch()} returned by previous calls to the same outbox. If the provided epoch does not + * match, an empty batch is returned with the correct epoch indicated in {@link MessageBatch#getEpoch()}. 
+ * + * The provided watermark must be greater than, or equal to, the previous watermark supplied to the same outbox. + * Any messages lower than the watermark are acknowledged and removed from the outbox. + * + * @param clientHost which outbox to retrieve messages from. Each clientHost has its own outbox. + * @param epoch outbox epoch, or {@link MessageRelay#INIT} if this is the first call from the collector. + * @param startWatermark outbox message watermark to retrieve from. + * + * @return future that resolves to the next batch of messages + */ + ListenableFuture> getMessages(String clientHost, long epoch, long startWatermark); + + /** + * Reset the outbox for a particular client. This removes all messages, cancels all outstanding futures, and + * resets the epoch. + * + * @param clientHost the client host:port + */ + void resetOutbox(String clientHost); +} diff --git a/server/src/main/java/org/apache/druid/messages/server/OutboxImpl.java b/server/src/main/java/org/apache/druid/messages/server/OutboxImpl.java new file mode 100644 index 000000000000..09e19177b945 --- /dev/null +++ b/server/src/main/java/org/apache/druid/messages/server/OutboxImpl.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.common.guava.FutureBox; +import org.apache.druid.common.guava.FutureUtils; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.lifecycle.LifecycleStop; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.client.MessageRelay; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; + +/** + * Production implementation of {@link Outbox}. Each outbox is represented by an {@link OutboxQueue}. 
+ */ +public class OutboxImpl implements Outbox +{ + private static final int MAX_BATCH_SIZE = 8; + + // clientHost -> outgoing message queue + private final ConcurrentHashMap> queues; + private volatile boolean stopped; + + public OutboxImpl() + { + this.queues = new ConcurrentHashMap<>(); + } + + @LifecycleStop + public void stop() + { + stopped = true; + + final Iterator> it = queues.values().iterator(); + while (it.hasNext()) { + it.next().stop(); + it.remove(); + } + } + + @Override + public ListenableFuture sendMessage(String clientHost, MessageType message) + { + if (stopped) { + return Futures.immediateCancelledFuture(); + } + + return queues.computeIfAbsent(clientHost, id -> new OutboxQueue<>()) + .sendMessage(message); + } + + @Override + public ListenableFuture> getMessages(String clientHost, long epoch, long startWatermark) + { + if (stopped) { + return Futures.immediateCancelledFuture(); + } + + final OutboxQueue queue = queues.computeIfAbsent(clientHost, id -> new OutboxQueue<>()); + if (epoch != queue.epoch && epoch != MessageRelay.INIT) { + return Futures.immediateFuture(new MessageBatch<>(Collections.emptyList(), queue.epoch, 0)); + } + + return queue.getMessages(startWatermark); + } + + @Override + public void resetOutbox(final String clientHost) + { + final OutboxQueue queue = queues.remove(clientHost); + if (queue != null) { + queue.stop(); + } + } + + @VisibleForTesting + long getOutboxEpoch(final String clientHost) + { + final OutboxQueue queue = queues.get(clientHost); + return queue != null ? queue.epoch : MessageRelay.INIT; + } + + /** + * Outgoing queue for a specific client. + */ + public static class OutboxQueue + { + /** + * Epoch, set when the outbox is created. Attached to returned batches through {@link MessageBatch#getEpoch()}. + */ + private final long epoch; + + /** + * Currently-outstanding futures. 
+ */ + private final FutureBox pendingFutures = new FutureBox(); + + @GuardedBy("this") + private long startWatermark = 0; + + @GuardedBy("this") + private final Deque, T>> queue = new ArrayDeque<>(); + + @GuardedBy("this") + private SettableFuture messageAvailableFuture = SettableFuture.create(); + + public OutboxQueue() + { + // Random positive number, to differentiate this outbox from a previous version that may have lived + // on the same host. (When the upstream relay connects, it needs to know if this is the "same" outbox + // it was previously listening to.) + this.epoch = ThreadLocalRandom.current().nextLong() & Long.MAX_VALUE; + } + + ListenableFuture sendMessage(final T message) + { + final SettableFuture future = SettableFuture.create(); + + synchronized (this) { + queue.add(Pair.of(future, message)); + if (!messageAvailableFuture.isDone()) { + messageAvailableFuture.set(null); + } + } + + return pendingFutures.register(future); + } + + ListenableFuture> getMessages(final long newStartWatermark) + { + synchronized (this) { + // Ack and drain all messages up to startWatermark. + while (!queue.isEmpty() && startWatermark < newStartWatermark) { + final Pair, T> message = queue.poll(); + startWatermark++; + message.lhs.set(null); + } + + if (queue.isEmpty()) { + // Send next batch when a message is available. 
+ if (messageAvailableFuture.isDone()) { + messageAvailableFuture = SettableFuture.create(); + } + + return pendingFutures.register( + FutureUtils.transform( + Futures.nonCancellationPropagating(messageAvailableFuture), + ignored -> { + synchronized (this) { + return nextBatch(); + } + } + ) + ); + } else { + return pendingFutures.register(Futures.immediateFuture(nextBatch())); + } + } + } + + void stop() + { + pendingFutures.close(); + } + + @GuardedBy("this") + private MessageBatch nextBatch() + { + final List batch = new ArrayList<>(); + final Iterator, T>> it = queue.iterator(); + + while (it.hasNext() && batch.size() < MAX_BATCH_SIZE) { + batch.add(it.next().rhs); + } + + return new MessageBatch<>(batch, epoch, startWatermark); + } + } +} diff --git a/server/src/main/java/org/apache/druid/rpc/FixedServiceLocator.java b/server/src/main/java/org/apache/druid/rpc/FixedServiceLocator.java new file mode 100644 index 000000000000..06e7bd993c18 --- /dev/null +++ b/server/src/main/java/org/apache/druid/rpc/FixedServiceLocator.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.rpc; + +import com.google.common.base.Preconditions; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; + +/** + * Locator for a fixed set of {@link ServiceLocations}. + */ +public class FixedServiceLocator implements ServiceLocator +{ + private final ServiceLocations locations; + + private volatile boolean closed = false; + + public FixedServiceLocator(final ServiceLocations locations) + { + this.locations = Preconditions.checkNotNull(locations); + } + + public FixedServiceLocator(final ServiceLocation location) + { + this(ServiceLocations.forLocation(location)); + } + + @Override + public ListenableFuture locate() + { + if (closed) { + return Futures.immediateFuture(ServiceLocations.closed()); + } else { + return Futures.immediateFuture(locations); + } + } + + @Override + public void close() + { + closed = true; + } +} diff --git a/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java b/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java deleted file mode 100644 index d6f6eff9d7fd..000000000000 --- a/server/src/main/java/org/apache/druid/rpc/FixedSetServiceLocator.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.druid.rpc; - -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import org.apache.druid.server.coordination.DruidServerMetadata; -import org.jboss.netty.util.internal.ThreadLocalRandom; - -import javax.validation.constraints.NotNull; -import java.util.Set; -import java.util.stream.Collectors; - -/** - * Basic implmentation of {@link ServiceLocator} that returns a service location from a static set of locations. Returns - * a random location each time one is requested. - */ -public class FixedSetServiceLocator implements ServiceLocator -{ - private ServiceLocations serviceLocations; - - private FixedSetServiceLocator(ServiceLocations serviceLocations) - { - this.serviceLocations = serviceLocations; - } - - public static FixedSetServiceLocator forServiceLocation(@NotNull ServiceLocation serviceLocation) - { - return new FixedSetServiceLocator(ServiceLocations.forLocation(serviceLocation)); - } - - public static FixedSetServiceLocator forDruidServerMetadata(Set serverMetadataSet) - { - if (serverMetadataSet == null || serverMetadataSet.isEmpty()) { - return new FixedSetServiceLocator(ServiceLocations.closed()); - } else { - Set serviceLocationSet = serverMetadataSet.stream() - .map(ServiceLocation::fromDruidServerMetadata) - .collect(Collectors.toSet()); - - return new FixedSetServiceLocator(ServiceLocations.forLocations(serviceLocationSet)); - } - } - - @Override - public ListenableFuture locate() - { - if (serviceLocations.isClosed() || serviceLocations.getLocations().isEmpty()) { - return Futures.immediateFuture(ServiceLocations.closed()); - } - - Set locationSet = serviceLocations.getLocations(); - int size = locationSet.size(); - if (size == 1) { - return Futures.immediateFuture(ServiceLocations.forLocation(locationSet.stream().findFirst().get())); - } - - return 
Futures.immediateFuture( - ServiceLocations.forLocation( - locationSet.stream() - .skip(ThreadLocalRandom.current().nextInt(size)) - .findFirst() - .orElse(null) - ) - ); - } - - @Override - public void close() - { - serviceLocations = ServiceLocations.closed(); - } -} diff --git a/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java b/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java index 3178360016ab..172f220fabad 100644 --- a/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java +++ b/server/src/main/java/org/apache/druid/rpc/ServiceClientImpl.java @@ -41,7 +41,6 @@ import javax.annotation.Nullable; import java.net.URI; -import java.net.URISyntaxException; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -497,19 +496,7 @@ static long computeBackoffMs(final ServiceRetryPolicy retryPolicy, final long at } /** - * Sanitizes IPv6 address if it has brackets. Eg. host = "[1:2:3:4:5:6:7:8]" will be returned as "1:2:3:4:5:6:7:8" - * after this function - */ - static String sanitizeHost(String host) - { - if (host.charAt(0) == '[') { - return host.substring(1, host.length() - 1); - } - return host; - } - - /** - * Returns a {@link ServiceLocation} without a path component, based on a URI. + * Returns a {@link ServiceLocation} without a path component, based on a URI. Returns null on invalid URIs. */ @Nullable @VisibleForTesting @@ -520,24 +507,17 @@ static ServiceLocation serviceLocationNoPathFromUri(@Nullable final String uriSt } try { - final URI uri = new URI(uriString); - - if (uri.getHost() == null) { - return null; - } - - final String scheme = uri.getScheme(); - final String host = sanitizeHost(uri.getHost()); - - if ("http".equals(scheme)) { - return new ServiceLocation(host, uri.getPort() < 0 ? 80 : uri.getPort(), -1, ""); - } else if ("https".equals(scheme)) { - return new ServiceLocation(host, -1, uri.getPort() < 0 ? 
443 : uri.getPort(), ""); - } else { - return null; - } + final ServiceLocation location = ServiceLocation.fromUri(URI.create(uriString)); + + // Strip path. + return new ServiceLocation( + location.getHost(), + location.getPlaintextPort(), + location.getTlsPort(), + "" + ); } - catch (URISyntaxException e) { + catch (IllegalArgumentException e) { return null; } } @@ -549,8 +529,8 @@ static ServiceLocation serviceLocationNoPathFromUri(@Nullable final String uriSt static boolean serviceLocationMatches(final ServiceLocation left, final ServiceLocation right) { return left.getHost().equals(right.getHost()) - && portMatches(left.getPlaintextPort(), right.getPlaintextPort()) - && portMatches(left.getTlsPort(), right.getTlsPort()); + && portMatches(left.getPlaintextPort(), right.getPlaintextPort()) + && portMatches(left.getTlsPort(), right.getTlsPort()); } static boolean portMatches(int left, int right) diff --git a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java index aeaa24318e93..974f09fe89bb 100644 --- a/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java +++ b/server/src/main/java/org/apache/druid/rpc/ServiceLocation.java @@ -22,6 +22,7 @@ import com.google.common.base.Preconditions; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableList; +import org.apache.druid.java.util.common.IAE; import org.apache.druid.java.util.common.ISE; import org.apache.druid.server.DruidNode; import org.apache.druid.server.coordination.DruidServerMetadata; @@ -29,6 +30,7 @@ import javax.annotation.Nullable; import javax.validation.constraints.NotNull; import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.util.Iterator; import java.util.Objects; @@ -40,6 +42,8 @@ public class ServiceLocation { private static final String HTTP_SCHEME = "http"; private static final String HTTPS_SCHEME = "https"; + private static final int 
HTTP_DEFAULT_PORT = 80; + private static final int HTTPS_DEFAULT_PORT = 443; private static final Splitter HOST_SPLITTER = Splitter.on(":").limit(2); private final String host; @@ -72,6 +76,50 @@ public static ServiceLocation fromDruidNode(final DruidNode druidNode) return new ServiceLocation(druidNode.getHost(), druidNode.getPlaintextPort(), druidNode.getTlsPort(), ""); } + /** + * Create a service location based on a {@link URI}. + * + * @throws IllegalArgumentException if the URI cannot be mapped to a service location. + */ + public static ServiceLocation fromUri(final URI uri) + { + if (uri == null || uri.getHost() == null) { + throw new IAE("URI[%s] has no host", uri); + } + + final String scheme = uri.getScheme(); + final String host = stripBrackets(uri.getHost()); + final StringBuilder basePath = new StringBuilder(); + + if (uri.getRawPath() != null) { + if (uri.getRawQuery() == null && uri.getRawFragment() == null && uri.getRawPath().endsWith("/")) { + // Strip trailing slash if the URI has no query or fragment. By convention, this trailing slash is not + // part of the service location. + basePath.append(uri.getRawPath(), 0, uri.getRawPath().length() - 1); + } else { + basePath.append(uri.getRawPath()); + } + } + + if (uri.getRawQuery() != null) { + basePath.append('?').append(uri.getRawQuery()); + } + + if (uri.getRawFragment() != null) { + basePath.append('#').append(uri.getRawFragment()); + } + + if (HTTP_SCHEME.equals(scheme)) { + final int port = uri.getPort() < 0 ? HTTP_DEFAULT_PORT : uri.getPort(); + return new ServiceLocation(host, port, -1, basePath.toString()); + } else if (HTTPS_SCHEME.equals(scheme)) { + final int port = uri.getPort() < 0 ? HTTPS_DEFAULT_PORT : uri.getPort(); + return new ServiceLocation(host, -1, port, basePath.toString()); + } else { + throw new IAE("URI[%s] has invalid scheme[%s]", uri, scheme); + } + } + /** * Create a service location based on a {@link DruidServerMetadata}. 
* @@ -133,6 +181,11 @@ public String getBasePath() return basePath; } + public ServiceLocation withBasePath(final String newBasePath) + { + return new ServiceLocation(host, plaintextPort, tlsPort, newBasePath); + } + public URL toURL(@Nullable final String encodedPathAndQueryString) { final String scheme; @@ -193,4 +246,15 @@ public String toString() '}'; } + /** + * Strips brackets from the host part of a URI, so we can better handle IPv6 addresses. + * e.g. host = "[1:2:3:4:5:6:7:8]" is transformed to "1:2:3:4:5:6:7:8" by this function + */ + static String stripBrackets(String host) + { + if (host.charAt(0) == '[' && host.charAt(host.length() - 1) == ']') { + return host.substring(1, host.length() - 1); + } + return host; + } } diff --git a/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java b/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java index 798b55ed7274..fd90ff905a22 100644 --- a/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java +++ b/server/src/test/java/org/apache/druid/client/BrokerServerViewTest.java @@ -25,6 +25,7 @@ import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -65,6 +66,7 @@ import java.util.Set; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; +import java.util.stream.Collectors; public class BrokerServerViewTest extends CuratorTestBase { @@ -290,6 +292,12 @@ public void testMultipleServerAndBroker() throws Exception ) ); + // check server metadatas + Assert.assertEquals( + druidServers.stream().map(DruidServer::getMetadata).collect(Collectors.toSet()), + ImmutableSet.copyOf(brokerServerView.getDruidServerMetadatas()) + ); + // unannounce the broker segment should do nothing to announcements
unannounceSegmentForServer(druidBroker, brokerSegment, zkPathsConfig); Assert.assertTrue(timing.forWaiting().awaitLatch(segmentRemovedLatch)); @@ -593,7 +601,8 @@ private void setupViews() throws Exception setupViews(null, null, true); } - private void setupViews(Set watchedTiers, Set ignoredTiers, boolean watchRealtimeTasks) throws Exception + private void setupViews(Set watchedTiers, Set ignoredTiers, boolean watchRealtimeTasks) + throws Exception { baseView = new BatchServerInventoryView( zkPathsConfig, diff --git a/server/src/test/java/org/apache/druid/messages/MessageBatchTest.java b/server/src/test/java/org/apache/druid/messages/MessageBatchTest.java new file mode 100644 index 000000000000..bcf9fb3423d1 --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/MessageBatchTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableList; +import nl.jqno.equalsverifier.EqualsVerifier; +import org.apache.druid.segment.TestHelper; +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; + +public class MessageBatchTest +{ + @Test + public void test_serde() throws IOException + { + final ObjectMapper objectMapper = TestHelper.JSON_MAPPER; + final MessageBatch batch = new MessageBatch<>(ImmutableList.of("foo", "bar"), 123L, 456L); + final MessageBatch batch2 = + objectMapper.readValue(objectMapper.writeValueAsBytes(batch), new TypeReference>() {}); + Assert.assertEquals(batch, batch2); + } + + @Test + public void test_equals() + { + EqualsVerifier.forClass(MessageBatch.class) + .usingGetClass() + .verify(); + } +} diff --git a/server/src/test/java/org/apache/druid/messages/client/MessageRelayClientImplTest.java b/server/src/test/java/org/apache/druid/messages/client/MessageRelayClientImplTest.java new file mode 100644 index 000000000000..7b8af75c8d41 --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/client/MessageRelayClientImplTest.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.client; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.smile.SmileFactory; +import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.util.concurrent.ListenableFuture; +import it.unimi.dsi.fastutil.bytes.ByteArrays; +import org.apache.druid.jackson.DefaultObjectMapper; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.rpc.MockServiceClient; +import org.apache.druid.rpc.RequestBuilder; +import org.jboss.netty.handler.codec.http.HttpMethod; +import org.jboss.netty.handler.codec.http.HttpResponseStatus; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import javax.ws.rs.core.HttpHeaders; +import java.util.Collections; + +public class MessageRelayClientImplTest +{ + private ObjectMapper smileMapper; + private MockServiceClient serviceClient; + private MessageRelayClient messageRelayClient; + + @Before + public void setup() + { + smileMapper = new DefaultObjectMapper(new SmileFactory(), null); + serviceClient = new MockServiceClient(); + messageRelayClient = new MessageRelayClientImpl<>(serviceClient, smileMapper, String.class); + } + + @After + public void tearDown() + { + serviceClient.verify(); + } + + @Test + public void test_getMessages_ok() throws Exception + { + final MessageBatch batch = new MessageBatch<>(ImmutableList.of("foo", "bar"), 123, 0); + + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/outbox/me/messages?epoch=-1&watermark=0"), + HttpResponseStatus.OK, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, SmileMediaTypes.APPLICATION_JACKSON_SMILE), + smileMapper.writeValueAsBytes(batch) + ); + + final ListenableFuture> result = 
messageRelayClient.getMessages("me", MessageRelay.INIT, 0); + Assert.assertEquals(batch, result.get()); + } + + @Test + public void test_getMessages_noContent() throws Exception + { + serviceClient.expectAndRespond( + new RequestBuilder(HttpMethod.GET, "/outbox/me/messages?epoch=-1&watermark=0"), + HttpResponseStatus.NO_CONTENT, + ImmutableMap.of(HttpHeaders.CONTENT_TYPE, SmileMediaTypes.APPLICATION_JACKSON_SMILE), + ByteArrays.EMPTY_ARRAY + ); + + final ListenableFuture> result = messageRelayClient.getMessages("me", MessageRelay.INIT, 0); + Assert.assertEquals(new MessageBatch<>(Collections.emptyList(), MessageRelay.INIT, 0), result.get()); + } +} diff --git a/server/src/test/java/org/apache/druid/messages/client/MessageRelaysTest.java b/server/src/test/java/org/apache/druid/messages/client/MessageRelaysTest.java new file mode 100644 index 000000000000..b2014450d81e --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/client/MessageRelaysTest.java @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.druid.messages.client; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.errorprone.annotations.concurrent.GuardedBy; +import org.apache.druid.discovery.DiscoveryDruidNode; +import org.apache.druid.discovery.DruidNodeDiscovery; +import org.apache.druid.discovery.NodeRole; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.server.Outbox; +import org.apache.druid.messages.server.OutboxImpl; +import org.apache.druid.server.DruidNode; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.function.Consumer; + +public class MessageRelaysTest +{ + private static final String MY_HOST = "me"; + private static final DruidNode OUTBOX_NODE = new DruidNode("service", "host", false, 80, -1, true, false); + private static final DiscoveryDruidNode OUTBOX_DISCO_NODE = new DiscoveryDruidNode( + new DruidNode("service", "host", false, 80, -1, true, false), + NodeRole.HISTORICAL, + Collections.emptyMap() + ); + + private Outbox outbox; + private TestMessageListener messageListener; + private TestDiscovery discovery; + private MessageRelays messageRelays; + + @Before + public void setUp() + { + outbox = new OutboxImpl<>(); + messageListener = new TestMessageListener(); + discovery = new TestDiscovery(); + messageRelays = new MessageRelays<>( + () -> discovery, + node -> { + Assert.assertEquals(OUTBOX_NODE, node); + return new MessageRelay<>( + MY_HOST, + node, + new OutboxMessageRelayClient(outbox), + messageListener + ); + } + ); + messageRelays.start(); + } + + @After + public void tearDown() + { + messageRelays.stop(); + Assert.assertEquals(Collections.emptyList(), discovery.getListeners()); + } + + @Test + public void test_serverAdded_thenRemoved() + { + 
discovery.fire(listener -> listener.nodesAdded(Collections.singletonList(OUTBOX_DISCO_NODE))); + discovery.fire(listener -> listener.nodesRemoved(Collections.singletonList(OUTBOX_DISCO_NODE))); + Assert.assertEquals(1, messageListener.getAdds()); + Assert.assertEquals(1, messageListener.getRemoves()); + } + + @Test + public void test_messageListener() + { + discovery.fire(listener -> listener.nodesAdded(Collections.singletonList(OUTBOX_DISCO_NODE))); + Assert.assertEquals(1, messageListener.getAdds()); + Assert.assertEquals(0, messageListener.getRemoves()); + + final ListenableFuture sendFuture = outbox.sendMessage(MY_HOST, "foo"); + Assert.assertEquals(ImmutableList.of("foo"), messageListener.getMessages()); + Assert.assertTrue(sendFuture.isDone()); + + final ListenableFuture sendFuture2 = outbox.sendMessage(MY_HOST, "bar"); + Assert.assertEquals(ImmutableList.of("foo", "bar"), messageListener.getMessages()); + Assert.assertTrue(sendFuture2.isDone()); + } + + /** + * Implementation of {@link MessageListener} that tracks all received messages. + */ + private static class TestMessageListener implements MessageListener + { + @GuardedBy("this") + private long adds; + + @GuardedBy("this") + private long removes; + + @GuardedBy("this") + private final List messages = new ArrayList<>(); + + @Override + public synchronized void serverAdded(DruidNode node) + { + adds++; + } + + @Override + public synchronized void messageReceived(String message) + { + messages.add(message); + } + + @Override + public synchronized void serverRemoved(DruidNode node) + { + removes++; + } + + public synchronized long getAdds() + { + return adds; + } + + public synchronized long getRemoves() + { + return removes; + } + + public synchronized List getMessages() + { + return ImmutableList.copyOf(messages); + } + } + + /** + * Implementation of {@link MessageRelayClient} that directly uses an {@link Outbox}, rather than contacting + * a remote outbox. 
+ */ + private static class OutboxMessageRelayClient implements MessageRelayClient + { + private final Outbox outbox; + + public OutboxMessageRelayClient(final Outbox outbox) + { + this.outbox = outbox; + } + + @Override + public ListenableFuture> getMessages(String clientHost, long epoch, long startWatermark) + { + return outbox.getMessages(clientHost, epoch, startWatermark); + } + } + + /** + * Implementation of {@link DruidNodeDiscovery} that allows firing listeners on command. + */ + private static class TestDiscovery implements DruidNodeDiscovery + { + @GuardedBy("this") + private final List listeners; + + public TestDiscovery() + { + listeners = new ArrayList<>(); + } + + @Override + public Collection getAllNodes() + { + throw new UnsupportedOperationException(); + } + + @Override + public synchronized void registerListener(Listener listener) + { + listeners.add(listener); + } + + @Override + public synchronized void removeListener(Listener listener) + { + listeners.remove(listener); + } + + public synchronized List getListeners() + { + return ImmutableList.copyOf(listeners); + } + + public synchronized void fire(Consumer f) + { + for (final Listener listener : listeners) { + f.accept(listener); + } + } + } +} diff --git a/server/src/test/java/org/apache/druid/messages/server/OutboxImplTest.java b/server/src/test/java/org/apache/druid/messages/server/OutboxImplTest.java new file mode 100644 index 000000000000..727c1c6ee2fc --- /dev/null +++ b/server/src/test/java/org/apache/druid/messages/server/OutboxImplTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.messages.server; + +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.druid.messages.MessageBatch; +import org.apache.druid.messages.client.MessageRelay; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Collections; +import java.util.concurrent.ExecutionException; + +public class OutboxImplTest +{ + private static final String HOST = "h1"; + + private OutboxImpl outbox; + + @Before + public void setUp() + { + outbox = new OutboxImpl<>(); + } + + @After + public void tearDown() + { + outbox.stop(); + } + + @Test + public void test_normalOperation() throws InterruptedException, ExecutionException + { + // Send first three messages. + final ListenableFuture sendFuture1 = outbox.sendMessage(HOST, "1"); + final ListenableFuture sendFuture2 = outbox.sendMessage(HOST, "2"); + final ListenableFuture sendFuture3 = outbox.sendMessage(HOST, "3"); + + final long outboxEpoch = outbox.getOutboxEpoch(HOST); + + // No messages are acknowledged. + Assert.assertFalse(sendFuture1.isDone()); + Assert.assertFalse(sendFuture2.isDone()); + Assert.assertFalse(sendFuture3.isDone()); + + // Request all three messages (startWatermark = 0). + Assert.assertEquals( + new MessageBatch<>(ImmutableList.of("1", "2", "3"), outboxEpoch, 0), + outbox.getMessages(HOST, MessageRelay.INIT, 0).get() + ); + + // No messages are acknowledged. 
+ Assert.assertFalse(sendFuture1.isDone()); + Assert.assertFalse(sendFuture2.isDone()); + Assert.assertFalse(sendFuture3.isDone()); + + // Request two of those messages again (startWatermark = 1). + Assert.assertEquals( + new MessageBatch<>(ImmutableList.of("2", "3"), outboxEpoch, 1), + outbox.getMessages(HOST, outboxEpoch, 1).get() + ); + + // First message is acknowledged. + Assert.assertTrue(sendFuture1.isDone()); + Assert.assertFalse(sendFuture2.isDone()); + Assert.assertFalse(sendFuture3.isDone()); + + // Request the high watermark (startWatermark = 3). + final ListenableFuture> futureBatch = outbox.getMessages(HOST, outboxEpoch, 3); + + // It's not available yet. + Assert.assertFalse(futureBatch.isDone()); + + // All messages are acknowledged. + Assert.assertTrue(sendFuture1.isDone()); + Assert.assertTrue(sendFuture2.isDone()); + Assert.assertTrue(sendFuture3.isDone()); + + // Send one more message; futureBatch resolves. + final ListenableFuture sendFuture4 = outbox.sendMessage(HOST, "4"); + Assert.assertTrue(futureBatch.isDone()); + + // sendFuture4 is not resolved. + Assert.assertFalse(sendFuture4.isDone()); + } + + @Test + public void test_getMessages_wrongEpoch() throws InterruptedException, ExecutionException + { + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + final long outboxEpoch = outbox.getOutboxEpoch(HOST); + + // Fetch with the wrong epoch. + final MessageBatch batch = outbox.getMessages(HOST, outboxEpoch + 1, 0).get(); + Assert.assertEquals( + new MessageBatch<>(Collections.emptyList(), outboxEpoch, 0), + batch + ); + + Assert.assertFalse(sendFuture.isDone()); + } + + @Test + public void test_getMessages_nonexistentHost() throws InterruptedException, ExecutionException + { + // Calling getMessages with a nonexistent host creates an outbox. 
+ final String nonexistentHost = "nonexistent"; + final ListenableFuture> batchFuture = outbox.getMessages( + nonexistentHost, + MessageRelay.INIT, + 0 + ); + Assert.assertFalse(batchFuture.isDone()); + + // Check that an outbox was created (it has an epoch). + MatcherAssert.assertThat(outbox.getOutboxEpoch(nonexistentHost), Matchers.greaterThanOrEqualTo(0L)); + + // getMessages future resolves when a message is sent. + final ListenableFuture sendFuture = outbox.sendMessage(nonexistentHost, "foo"); + Assert.assertTrue(batchFuture.isDone()); + Assert.assertEquals( + new MessageBatch<>(ImmutableList.of("foo"), outbox.getOutboxEpoch(nonexistentHost), 0), + batchFuture.get() + ); + + // As usual, sendFuture resolves when the high watermark is requested. + Assert.assertFalse(sendFuture.isDone()); + final ListenableFuture> batchFuture2 = + outbox.getMessages(nonexistentHost, outbox.getOutboxEpoch(nonexistentHost), 1); + + Assert.assertTrue(sendFuture.isDone()); + + outbox.resetOutbox(nonexistentHost); + Assert.assertTrue(batchFuture2.isDone()); + } + + @Test + public void test_stop_cancelsSendMessage() + { + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + outbox.stop(); + Assert.assertTrue(sendFuture.isCancelled()); + } + + @Test + public void test_stop_cancelsGetMessages() + { + final ListenableFuture> futureBatch = outbox.getMessages(HOST, MessageRelay.INIT, 0); + outbox.stop(); + Assert.assertTrue(futureBatch.isCancelled()); + } + + @Test + public void test_reset_cancelsSendMessage() + { + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + outbox.resetOutbox(HOST); + Assert.assertTrue(sendFuture.isCancelled()); + } + + @Test + public void test_reset_cancelsGetMessages() + { + final ListenableFuture> futureBatch = outbox.getMessages(HOST, MessageRelay.INIT, 0); + outbox.resetOutbox(HOST); + Assert.assertTrue(futureBatch.isCancelled()); + } + + @Test + public void test_reset_nonexistentHost_doesNothing() + { + 
outbox.resetOutbox("nonexistent"); + } + + @Test + public void test_stop_preventsSendMessage() + { + outbox.stop(); + final ListenableFuture sendFuture = outbox.sendMessage(HOST, "1"); + Assert.assertTrue(sendFuture.isCancelled()); + } + + @Test + public void test_stop_preventsGetMessages() + { + outbox.stop(); + final ListenableFuture> futureBatch = outbox.getMessages(HOST, MessageRelay.INIT, 0); + Assert.assertTrue(futureBatch.isCancelled()); + } +} diff --git a/server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java b/server/src/test/java/org/apache/druid/rpc/FixedServiceLocatorTest.java similarity index 56% rename from server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java rename to server/src/test/java/org/apache/druid/rpc/FixedServiceLocatorTest.java index b0b92f5e271b..c7775bfac833 100644 --- a/server/src/test/java/org/apache/druid/rpc/FixedSetServiceLocatorTest.java +++ b/server/src/test/java/org/apache/druid/rpc/FixedServiceLocatorTest.java @@ -27,7 +27,7 @@ import java.util.concurrent.ExecutionException; -public class FixedSetServiceLocatorTest +public class FixedServiceLocatorTest { public static final DruidServerMetadata DATA_SERVER_1 = new DruidServerMetadata( "TestDataServer", @@ -50,19 +50,24 @@ public class FixedSetServiceLocatorTest ); @Test - public void testLocateNullShouldBeClosed() throws ExecutionException, InterruptedException + public void test_constructor_rejectsNull() { - FixedSetServiceLocator serviceLocator - = FixedSetServiceLocator.forDruidServerMetadata(null); + Assert.assertThrows( + NullPointerException.class, + () -> new FixedServiceLocator((ServiceLocation) null) + ); - Assert.assertTrue(serviceLocator.locate().get().isClosed()); + Assert.assertThrows( + NullPointerException.class, + () -> new FixedServiceLocator((ServiceLocations) null) + ); } @Test - public void testLocateSingleServer() throws ExecutionException, InterruptedException + public void test_locate_singleServer() throws 
ExecutionException, InterruptedException { - FixedSetServiceLocator serviceLocator - = FixedSetServiceLocator.forDruidServerMetadata(ImmutableSet.of(DATA_SERVER_1)); + FixedServiceLocator serviceLocator = + new FixedServiceLocator(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)); Assert.assertEquals( ServiceLocations.forLocation(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)), @@ -71,16 +76,30 @@ public void testLocateSingleServer() throws ExecutionException, InterruptedExcep } @Test - public void testLocateMultipleServers() throws ExecutionException, InterruptedException + public void test_locate_afterClose() throws ExecutionException, InterruptedException { - FixedSetServiceLocator serviceLocator - = FixedSetServiceLocator.forDruidServerMetadata(ImmutableSet.of(DATA_SERVER_1, DATA_SERVER_2)); + FixedServiceLocator serviceLocator = + new FixedServiceLocator(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)); + + serviceLocator.close(); - Assert.assertTrue( + Assert.assertEquals( + ServiceLocations.closed(), + serviceLocator.locate().get() + ); + } + + @Test + public void test_locate_multipleServers() throws ExecutionException, InterruptedException + { + final ServiceLocations locations = ServiceLocations.forLocations( ImmutableSet.of( - ServiceLocations.forLocation(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1)), - ServiceLocations.forLocation(ServiceLocation.fromDruidServerMetadata(DATA_SERVER_2)) - ).contains(serviceLocator.locate().get()) + ServiceLocation.fromDruidServerMetadata(DATA_SERVER_1), + ServiceLocation.fromDruidServerMetadata(DATA_SERVER_2) + ) ); + + FixedServiceLocator serviceLocator = new FixedServiceLocator(locations); + Assert.assertEquals(locations, serviceLocator.locate().get()); } } diff --git a/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java b/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java index 69cb12e423ca..7346edd5cf6b 100644 --- 
a/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java +++ b/server/src/test/java/org/apache/druid/rpc/ServiceClientImplTest.java @@ -685,14 +685,6 @@ public void test_serviceLocationNoPathFromUri() ); } - @Test - public void test_normalizeHost() - { - Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceClientImpl.sanitizeHost("[1:2:3:4:5:6:7:8]")); - Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceClientImpl.sanitizeHost("1:2:3:4:5:6:7:8")); - Assert.assertEquals("1.2.3.4", ServiceClientImpl.sanitizeHost("1.2.3.4")); - } - @Test public void test_isRedirect() { diff --git a/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java b/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java index 6aec0e2b6060..8d95e79dd966 100644 --- a/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java +++ b/server/src/test/java/org/apache/druid/rpc/ServiceLocationTest.java @@ -25,8 +25,48 @@ import org.junit.Assert; import org.junit.Test; +import java.net.URI; + public class ServiceLocationTest { + @Test + public void test_stripBrackets() + { + Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceLocation.stripBrackets("[1:2:3:4:5:6:7:8]")); + Assert.assertEquals("1:2:3:4:5:6:7:8", ServiceLocation.stripBrackets("1:2:3:4:5:6:7:8")); + Assert.assertEquals("1.2.3.4", ServiceLocation.stripBrackets("1.2.3.4")); + } + + @Test + public void test_fromUri_http() + { + final ServiceLocation location = ServiceLocation.fromUri(URI.create("http://example.com:8100/xyz")); + Assert.assertEquals("example.com", location.getHost()); + Assert.assertEquals(-1, location.getTlsPort()); + Assert.assertEquals(8100, location.getPlaintextPort()); + Assert.assertEquals("/xyz", location.getBasePath()); + } + + @Test + public void test_fromUri_https_defaultPort() + { + final ServiceLocation location = ServiceLocation.fromUri(URI.create("https://example.com/xyz")); + Assert.assertEquals("example.com", location.getHost()); + Assert.assertEquals(443, location.getTlsPort()); + 
Assert.assertEquals(-1, location.getPlaintextPort()); + Assert.assertEquals("/xyz", location.getBasePath()); + } + + @Test + public void test_fromUri_https() + { + final ServiceLocation location = ServiceLocation.fromUri(URI.create("https://example.com:8100/xyz")); + Assert.assertEquals("example.com", location.getHost()); + Assert.assertEquals(8100, location.getTlsPort()); + Assert.assertEquals(-1, location.getPlaintextPort()); + Assert.assertEquals("/xyz", location.getBasePath()); + } + @Test public void test_fromDruidServerMetadata_withPort() { diff --git a/services/src/main/java/org/apache/druid/cli/CliHistorical.java b/services/src/main/java/org/apache/druid/cli/CliHistorical.java index 2e231bcdcc3b..ea8bbd994348 100644 --- a/services/src/main/java/org/apache/druid/cli/CliHistorical.java +++ b/services/src/main/java/org/apache/druid/cli/CliHistorical.java @@ -42,6 +42,7 @@ import org.apache.druid.guice.ManageLifecycle; import org.apache.druid.guice.QueryRunnerFactoryModule; import org.apache.druid.guice.QueryableModule; +import org.apache.druid.guice.SegmentWranglerModule; import org.apache.druid.guice.ServerTypeConfig; import org.apache.druid.java.util.common.logger.Logger; import org.apache.druid.query.QuerySegmentWalker; @@ -99,6 +100,7 @@ protected List getModules() new DruidProcessingModule(), new QueryableModule(), new QueryRunnerFactoryModule(), + new SegmentWranglerModule(), new JoinableFactoryModule(), new HistoricalServiceModule(), binder -> { diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java index 92f2ef2ea811..dfd288f03d42 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/IngestHandler.java @@ -186,7 +186,8 @@ protected RelDataType returnedRowType() final RelDataTypeFactory typeFactory = rootQueryRel.rel.getCluster().getTypeFactory(); return 
handlerContext.engine().resultTypeForInsert( typeFactory, - rootQueryRel.validatedRowType + rootQueryRel.validatedRowType, + handlerContext.queryContextMap() ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java b/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java index 82dd6afe8c99..a0676ef45efc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/planner/QueryHandler.java @@ -736,7 +736,8 @@ protected RelDataType returnedRowType() final RelDataTypeFactory typeFactory = rootQueryRel.rel.getCluster().getTypeFactory(); return handlerContext.engine().resultTypeForSelect( typeFactory, - rootQueryRel.validatedRowType + rootQueryRel.validatedRowType, + handlerContext.plannerContext().queryContextMap() ); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java index d02d302437b8..4f3d86b1b420 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/NativeSqlEngine.java @@ -83,13 +83,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return validatedRowType; } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { throw new UnsupportedOperationException(); } diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java index 
fec7660e44ef..1d33b019e684 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/run/SqlEngine.java @@ -57,8 +57,13 @@ public interface SqlEngine * * @param typeFactory type factory * @param validatedRowType row type from Calcite's validator + * @param queryContext query context, in case that affects the result type */ - RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType); + RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ); /** * SQL row type that would be emitted by the {@link QueryMaker} from {@link #buildQueryMakerForInsert}. @@ -66,8 +71,13 @@ public interface SqlEngine * * @param typeFactory type factory * @param validatedRowType row type from Calcite's validator + * @param queryContext query context, in case that affects the result type */ - RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType); + RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ); /** * Create a {@link QueryMaker} for a SELECT query. 
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java index 716fa50b85f1..7563b45d52bc 100644 --- a/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java +++ b/sql/src/main/java/org/apache/druid/sql/calcite/view/ViewSqlEngine.java @@ -93,13 +93,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return validatedRowType; } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { // Can't have views of INSERT or REPLACE statements. throw new UnsupportedOperationException(); diff --git a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java index 4adea5d8d84e..d957e7155b5e 100644 --- a/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java +++ b/sql/src/main/java/org/apache/druid/sql/http/SqlResource.java @@ -82,7 +82,7 @@ public class SqlResource private final DruidNode selfNode; @Inject - SqlResource( + protected SqlResource( final ObjectMapper jsonMapper, final AuthorizerMapper authorizerMapper, final @NativeQuery SqlStatementFactory sqlStatementFactory, @@ -140,19 +140,7 @@ public Response cancelQuery( return Response.status(Status.NOT_FOUND).build(); } - // Considers only datasource and table resources; not context - // key resources when checking permissions. This means that a user's - // permission to cancel a query depends on the datasource, not the - // context variables used in the query. 
- Set resources = lifecycles - .stream() - .flatMap(lifecycle -> lifecycle.resources().stream()) - .collect(Collectors.toSet()); - Access access = AuthorizationUtils.authorizeAllResourceActions( - req, - resources, - authorizerMapper - ); + final Access access = authorizeCancellation(req, lifecycles); if (access.isAllowed()) { // should remove only the lifecycles in the snapshot. @@ -341,4 +329,23 @@ public void writeException(Exception ex, OutputStream out) throws IOException out.write(jsonMapper.writeValueAsBytes(ex)); } } + + /** + * Authorize a query cancellation operation. + * + * Considers only datasource and table resources; not context key resources when checking permissions. This means + * that a user's permission to cancel a query depends on the datasource, not the context variables used in the query. + */ + public Access authorizeCancellation(final HttpServletRequest req, final List cancelables) + { + Set resources = cancelables + .stream() + .flatMap(lifecycle -> lifecycle.resources().stream()) + .collect(Collectors.toSet()); + return AuthorizationUtils.authorizeAllResourceActions( + req, + resources, + authorizerMapper + ); + } } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java index abab053dd6bb..80a9dde9b4c9 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteScanSignatureTest.java @@ -155,13 +155,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { return validatedRowType; } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + 
public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { throw new UnsupportedOperationException(); } diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java index 466bd0e390bd..569598af1e4e 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/IngestionTestSqlEngine.java @@ -56,13 +56,21 @@ public void validateContext(Map queryContext) } @Override - public RelDataType resultTypeForSelect(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForSelect( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { throw new UnsupportedOperationException(); } @Override - public RelDataType resultTypeForInsert(RelDataTypeFactory typeFactory, RelDataType validatedRowType) + public RelDataType resultTypeForInsert( + RelDataTypeFactory typeFactory, + RelDataType validatedRowType, + Map queryContext + ) { // Matches the return structure of TestInsertQueryMaker. 
return typeFactory.createStructType( diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java b/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java index bd80aee8cdad..58990e806617 100644 --- a/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java +++ b/sql/src/test/java/org/apache/druid/sql/calcite/util/TestTimelineServerView.java @@ -34,7 +34,6 @@ import org.apache.druid.timeline.DataSegment; import org.apache.druid.timeline.TimelineLookup; -import javax.annotation.Nullable; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -97,7 +96,6 @@ public Optional> getTimeline(Da throw new UnsupportedOperationException(); } - @Nullable @Override public List getDruidServers() {