From 464b153850bd3e370687604c939bccf57d0c2b4c Mon Sep 17 00:00:00 2001 From: Adarsh Sanjeev Date: Wed, 25 Sep 2024 09:15:32 +0530 Subject: [PATCH] Make sketch encoding configurable (#17086) Makes sketch encoding in MSQ configurable by the user. This would allow a user to configure the sketch encoding method for a specific query. The default is octet stream encoding. --- .../apache/druid/msq/exec/ControllerImpl.java | 3 +- .../exec/ExceptionWrappingWorkerClient.java | 11 ++++-- .../apache/druid/msq/exec/WorkerClient.java | 7 +++- .../druid/msq/exec/WorkerSketchFetcher.java | 11 ++++-- .../druid/msq/rpc/BaseWorkerClientImpl.java | 10 +++-- .../apache/druid/msq/rpc/SketchEncoding.java | 39 +++++++++++++++++++ .../apache/druid/msq/rpc/WorkerResource.java | 15 ------- .../msq/util/MultiStageQueryContext.java | 13 +++++++ .../msq/exec/WorkerSketchFetcherTest.java | 27 ++++++------- .../druid/msq/test/MSQTestWorkerClient.java | 7 +++- 10 files changed, 99 insertions(+), 44 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/SketchEncoding.java diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 72d8216088fe..019e998182f6 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -633,7 +633,8 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer) this.workerSketchFetcher = new WorkerSketchFetcher( netClient, workerManager, - queryKernelConfig.isFaultTolerant() + queryKernelConfig.isFaultTolerant(), + MultiStageQueryContext.getSketchEncoding(querySpec.getQuery().context()) ); closer.register(workerSketchFetcher::close); diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java index 8a7607d3159a..3373bbd883ee 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ExceptionWrappingWorkerClient.java @@ -32,6 +32,7 @@ import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.rpc.SketchEncoding; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import javax.annotation.Nullable; @@ -60,23 +61,25 @@ public ListenableFuture postWorkOrder(String workerTaskId, WorkOrder workO @Override public ListenableFuture fetchClusterByStatisticsSnapshot( String workerTaskId, - StageId stageId + StageId stageId, + SketchEncoding sketchEncoding ) { - return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId)); + return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId, sketchEncoding)); } @Override public ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( String workerTaskId, StageId stageId, - long timeChunk + long timeChunk, + SketchEncoding sketchEncoding ) { return wrap( workerTaskId, client, - c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk) + c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk, sketchEncoding) ); } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java index 572051124a74..4e7d506815e1 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerClient.java @@ -25,6 +25,7 @@ import org.apache.druid.msq.counters.CounterSnapshotsTree; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.rpc.SketchEncoding; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import java.io.Closeable; @@ -47,7 +48,8 @@ public interface WorkerClient extends Closeable */ ListenableFuture fetchClusterByStatisticsSnapshot( String workerId, - StageId stageId + StageId stageId, + SketchEncoding sketchEncoding ); /** @@ -57,7 +59,8 @@ ListenableFuture fetchClusterByStatisticsSnapshot( ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( String workerId, StageId stageId, - long timeChunk + long timeChunk, + SketchEncoding sketchEncoding ); /** diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java index 73f151fcdaa9..b0da5b83a46b 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/WorkerSketchFetcher.java @@ -34,6 +34,7 @@ import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; +import org.apache.druid.msq.rpc.SketchEncoding; import org.apache.druid.msq.statistics.ClusterByStatisticsCollector; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation; @@ -57,6 +58,7 @@ public class WorkerSketchFetcher implements AutoCloseable private static final int DEFAULT_THREAD_COUNT = 4; private final WorkerClient workerClient; + private final SketchEncoding sketchEncoding; private final WorkerManager workerManager; private final boolean retryEnabled; @@ -68,10 +70,12 @@ public class WorkerSketchFetcher implements AutoCloseable public WorkerSketchFetcher( WorkerClient workerClient, WorkerManager workerManager, - boolean retryEnabled + boolean retryEnabled, + SketchEncoding sketchEncoding ) { this.workerClient = workerClient; + this.sketchEncoding = sketchEncoding; this.executorService = Execs.multiThreaded(DEFAULT_THREAD_COUNT, "SketchFetcherThreadPool-%d"); this.workerManager = workerManager; this.retryEnabled = retryEnabled; @@ -96,7 +100,7 @@ public void inMemoryFullSketchMerging( executorService.submit(() -> { fetchStatsFromWorker( kernelActions, - () -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId), + () -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId, sketchEncoding), taskId, (kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForAllTimeChunks(stageId, workerNumber, snapshot), @@ -252,7 +256,8 @@ public void sequentialTimeChunkMerging( () -> workerClient.fetchClusterByStatisticsSnapshotForTimeChunk( taskId, new StageId(stageId.getQueryId(), stageId.getStageNumber()), - timeChunk + timeChunk, + sketchEncoding ), taskId, (kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForTimeChunk( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java index d6e7d412acad..74f3e780cfea 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/BaseWorkerClientImpl.java @@ -91,14 +91,15 @@ public ListenableFuture postWorkOrder(String workerId, WorkOrder workOrder @Override public ListenableFuture fetchClusterByStatisticsSnapshot( String workerId, - StageId stageId + StageId stageId, + SketchEncoding sketchEncoding ) { String path = StringUtils.format( "/keyStatistics/%s/%d?sketchEncoding=%s", StringUtils.urlEncode(stageId.getQueryId()), stageId.getStageNumber(), - WorkerResource.SketchEncoding.OCTET_STREAM + sketchEncoding ); return getClient(workerId).asyncRequest( @@ -111,7 +112,8 @@ public ListenableFuture fetchClusterByStatisticsSna public ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( String workerId, StageId stageId, - long timeChunk + long timeChunk, + SketchEncoding sketchEncoding ) { String path = StringUtils.format( @@ -119,7 +121,7 @@ public ListenableFuture fetchClusterByStatisticsSna StringUtils.urlEncode(stageId.getQueryId()), stageId.getStageNumber(), timeChunk, - WorkerResource.SketchEncoding.OCTET_STREAM + sketchEncoding ); return getClient(workerId).asyncRequest( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/SketchEncoding.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/SketchEncoding.java new file mode 100644 index 000000000000..11f4450df934 --- /dev/null +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/SketchEncoding.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.rpc; + + +import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde; + +/** + * Determines the encoding of key collectors returned by {@link WorkerResource#httpFetchKeyStatistics} and + * {@link WorkerResource#httpFetchKeyStatisticsWithSnapshot}. + */ +public enum SketchEncoding +{ + /** + * The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}. + */ + OCTET_STREAM, + /** + * The key collector is encoded as json + */ + JSON +} diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java index a0bfecff5427..839defa6bd9c 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/rpc/WorkerResource.java @@ -373,19 +373,4 @@ public Response httpGetCounters(@Context final HttpServletRequest req) return Response.status(Response.Status.OK).entity(worker.getCounters()).build(); } - /** - * Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and - * {@link #httpFetchKeyStatisticsWithSnapshot}. - */ - public enum SketchEncoding - { - /** - * The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}. - */ - OCTET_STREAM, - /** - * The key collector is encoded as json - */ - JSON - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index 4ed98dca594e..03cec9d192fc 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -38,6 +38,7 @@ import org.apache.druid.msq.indexing.error.MSQWarnings; import org.apache.druid.msq.kernel.WorkerAssignmentStrategy; import org.apache.druid.msq.rpc.ControllerResource; +import org.apache.druid.msq.rpc.SketchEncoding; import org.apache.druid.msq.sql.MSQMode; import org.apache.druid.query.QueryContext; import org.apache.druid.query.QueryContexts; @@ -137,6 +138,9 @@ public class MultiStageQueryContext public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode"; public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.SEQUENTIAL.toString(); + public static final String CTX_SKETCH_ENCODING_MODE = "sketchEncoding"; + public static final String DEFAULT_CTX_SKETCH_ENCODING_MODE = SketchEncoding.OCTET_STREAM.toString(); + public static final String CTX_ROWS_PER_SEGMENT = "rowsPerSegment"; public static final int DEFAULT_ROWS_PER_SEGMENT = 3000000; @@ -273,6 +277,15 @@ public static ClusterStatisticsMergeMode getClusterStatisticsMergeMode(QueryCont ); } + public static SketchEncoding getSketchEncoding(QueryContext queryContext) + { + return QueryContexts.getAsEnum( + CTX_SKETCH_ENCODING_MODE, + queryContext.getString(CTX_SKETCH_ENCODING_MODE, DEFAULT_CTX_SKETCH_ENCODING_MODE), + SketchEncoding.class + ); + } + public static boolean isFinalizeAggregations(final QueryContext queryContext) { return queryContext.getBoolean( diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java index cba8ede156ce..c9c053b04a3f 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/WorkerSketchFetcherTest.java @@ -27,6 +27,7 @@ import org.apache.druid.msq.kernel.StageDefinition; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.controller.ControllerQueryKernel; +import org.apache.druid.msq.rpc.SketchEncoding; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation; import org.junit.After; @@ -101,13 +102,13 @@ public void test_submitFetcherTask_parallelFetch() throws InterruptedException final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM)); // When fetching snapshots, return a mock and add it to queue doAnswer(invocation -> { ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class); return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any()); + }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), any()); target.inMemoryFullSketchMerging((kernelConsumer) -> { kernelConsumer.accept(kernel); @@ -124,13 +125,13 @@ public void test_submitFetcherTask_sequentialFetch() throws InterruptedException doReturn(true).when(completeKeyStatisticsInformation).isComplete(); final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM)); // When fetching snapshots, return a mock and add it to queue doAnswer(invocation -> { ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class); return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong()); + }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong(), any()); target.sequentialTimeChunkMerging( (kernelConsumer) -> { @@ -152,7 +153,7 @@ public void test_sequentialMerge_nonCompleteInformation() { doReturn(false).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM)); Assert.assertThrows(ISE.class, () -> target.sequentialTimeChunkMerging( (ignore) -> {}, completeKeyStatisticsInformation, @@ -167,7 +168,7 @@ public void test_inMemoryRetryEnabled_retryInvoked() throws InterruptedException { final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1)); @@ -196,7 +197,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti doReturn(true).when(completeKeyStatisticsInformation).isComplete(); final CountDownLatch latch = new CountDownLatch(TASK_IDS.size()); - target = spy(new WorkerSketchFetcher(workerClient, workerManager, true)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1)); CountDownLatch retryLatch = new CountDownLatch(1); @@ -223,7 +224,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedException { - target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1, TASK_0)); @@ -252,7 +253,7 @@ public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedExce public void test_InMemoryRetryDisabled_singleFailure() throws InterruptedException { - target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM)); workersWithFailedFetchParallel(ImmutableSet.of(TASK_1)); @@ -283,7 +284,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx { doReturn(true).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1, TASK_0)); @@ -315,7 +316,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx public void test_SequentialRetryDisabled_singleFailure() throws InterruptedException { doReturn(true).when(completeKeyStatisticsInformation).isComplete(); - target = spy(new WorkerSketchFetcher(workerClient, workerManager, false)); + target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM)); workersWithFailedFetchSequential(ImmutableSet.of(TASK_1)); @@ -352,7 +353,7 @@ private void workersWithFailedFetchSequential(Set failedTasks) return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0))); } return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong()); + }).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong(), any()); } private void workersWithFailedFetchParallel(Set failedTasks) @@ -363,7 +364,7 @@ private void workersWithFailedFetchParallel(Set failedTasks) return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0))); } return Futures.immediateFuture(snapshot); - }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any()); + }).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), any()); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java index 65145b5f5c01..ffd7c67ca2d6 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/test/MSQTestWorkerClient.java @@ -29,6 +29,7 @@ import org.apache.druid.msq.exec.WorkerClient; import org.apache.druid.msq.kernel.StageId; import org.apache.druid.msq.kernel.WorkOrder; +import org.apache.druid.msq.rpc.SketchEncoding; import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot; import java.io.InputStream; @@ -54,7 +55,8 @@ public ListenableFuture postWorkOrder(String workerTaskId, WorkOrder workO @Override public ListenableFuture fetchClusterByStatisticsSnapshot( String workerTaskId, - StageId stageId + StageId stageId, + SketchEncoding sketchEncoding ) { return Futures.immediateFuture(inMemoryWorkers.get(workerTaskId).fetchStatisticsSnapshot(stageId)); @@ -64,7 +66,8 @@ public ListenableFuture fetchClusterByStatisticsSna public ListenableFuture fetchClusterByStatisticsSnapshotForTimeChunk( String workerTaskId, StageId stageId, - long timeChunk + long timeChunk, + SketchEncoding sketchEncoding ) { return Futures.immediateFuture(