From b89b93a9ff0045fd3ab8853bd98d9ddbd0cb4e11 Mon Sep 17 00:00:00 2001 From: Karan Kumar Date: Thu, 12 Oct 2023 14:01:46 +0530 Subject: [PATCH] Limit page size to a configurable limit (#14994) Adds the ability to limit the page sizes of select query results. We piggyback on the same machinery that is used to control numRowsPerSegment. This patch introduces a new context parameter, rowsPerPage, whose default value is 100000 rows. This patch also optimizes the query plan by adding the final selectResults stage only when the previous stages produce sorted output. Previously, every select query with selectDestination=durableStorage added this extra selectResults stage. --- docs/multi-stage-query/reference.md | 2 +- .../apache/druid/msq/exec/ControllerImpl.java | 43 ++- .../msq/querykit/ShuffleSpecFactories.java | 14 + .../druid/msq/sql/entity/PageInformation.java | 63 +++- .../sql/resources/SqlStatementResource.java | 15 +- .../msq/util/MultiStageQueryContext.java | 12 + .../msq/util/SqlStatementResourceHelper.java | 75 ++-- .../apache/druid/msq/exec/MSQSelectTest.java | 165 ++++++++- .../sql/entity/ResultSetInformationTest.java | 2 +- .../sql/entity/SqlStatementResultTest.java | 2 +- .../SqlMSQStatementResourcePostTest.java | 135 +++++++- .../apache/druid/msq/test/MSQTestBase.java | 53 +-- .../util/SqlStatementResourceHelperTest.java | 322 ++++++++++++++++++ 13 files changed, 802 insertions(+), 101 deletions(-) create mode 100644 extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java diff --git a/docs/multi-stage-query/reference.md b/docs/multi-stage-query/reference.md index 5e80e318b8c88..9ec50de0a9b1a 100644 --- a/docs/multi-stage-query/reference.md +++ b/docs/multi-stage-query/reference.md @@ -248,7 +248,7 @@ The following table lists the context parameters for the MSQ task engine: | `selectDestination` | SELECT

Controls where the final result of the select query is written.
Use `taskReport` (the default) to write select results to the task report. This is not scalable, since the task report size grows very large for large result sets.
Use `durableStorage` to write results to a durable storage location. For large result sets, it is recommended to use `durableStorage`. To configure durable storage, see the [durable storage](#durable-storage) section. | `taskReport` | | `waitTillSegmentsLoad` | INSERT, REPLACE

If set, the ingest query waits for the generated segments to be loaded before exiting; otherwise, the ingest query exits without waiting. If this flag is set, the task and live reports contain information about the status of segment loading. This ensures that any queries made after the ingestion exits include results from the ingestion. The drawback is that the controller task stalls until the segments are loaded. | `false` | | `includeSegmentSource` | SELECT, INSERT, REPLACE

Controls the sources that will be queried for results, in addition to the segments present on deep storage. Can be `NONE` or `REALTIME`. If this value is `NONE`, only non-realtime (published and used) segments will be downloaded from deep storage. If this value is `REALTIME`, results will also be included from realtime tasks. | `NONE` | - +| `rowsPerPage` | SELECT

The number of rows per page to target. The actual number of rows per page may be somewhat higher or lower than this number. In most cases, use the default.
This property comes into effect only when `selectDestination` is set to `durableStorage` | 100000 | ## Joins Joins in multi-stage queries use one of two algorithms based on what you set the [context parameter](#context-parameters) `sqlJoinAlgorithm` to: diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java index 6f46007d93c04..3dc2e099c5e8e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/exec/ControllerImpl.java @@ -145,7 +145,6 @@ import org.apache.druid.msq.input.table.DataSegmentWithLocation; import org.apache.druid.msq.input.table.TableInputSpec; import org.apache.druid.msq.input.table.TableInputSpecSlicer; -import org.apache.druid.msq.kernel.GlobalSortTargetSizeShuffleSpec; import org.apache.druid.msq.kernel.QueryDefinition; import org.apache.druid.msq.kernel.QueryDefinitionBuilder; import org.apache.druid.msq.kernel.StageDefinition; @@ -1663,12 +1662,7 @@ private static QueryDefinition makeQueryDefinition( final ShuffleSpecFactory shuffleSpecFactory; if (MSQControllerTask.isIngestion(querySpec)) { - shuffleSpecFactory = (clusterBy, aggregate) -> - new GlobalSortTargetSizeShuffleSpec( - clusterBy, - tuningConfig.getRowsPerSegment(), - aggregate - ); + shuffleSpecFactory = ShuffleSpecFactories.getGlobalSortWithTargetSize(tuningConfig.getRowsPerSegment()); if (!columnMappings.hasUniqueOutputColumnNames()) { // We do not expect to hit this case in production, because the SQL validator checks that column names @@ -1693,8 +1687,9 @@ private static QueryDefinition makeQueryDefinition( shuffleSpecFactory = ShuffleSpecFactories.singlePartition(); queryToPlan = querySpec.getQuery(); } else if (querySpec.getDestination() instanceof DurableStorageMSQDestination) { - // we add a final stage which generates one partition per worker. - shuffleSpecFactory = ShuffleSpecFactories.globalSortWithMaxPartitionCount(tuningConfig.getMaxNumWorkers()); + shuffleSpecFactory = ShuffleSpecFactories.getGlobalSortWithTargetSize( + MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context()) + ); queryToPlan = querySpec.getQuery(); } else { throw new ISE("Unsupported destination [%s]", querySpec.getDestination()); @@ -1772,27 +1767,29 @@ private static QueryDefinition makeQueryDefinition( return queryDef; } else if (querySpec.getDestination() instanceof DurableStorageMSQDestination) { - // attaching new query results stage always. + // attaching new query results stage if the final stage does sort during shuffle so that results are ordered. 
StageDefinition finalShuffleStageDef = queryDef.getFinalStageDefinition(); - final QueryDefinitionBuilder builder = QueryDefinition.builder(); - for (final StageDefinition stageDef : queryDef.getStageDefinitions()) { - builder.add(StageDefinition.builder(stageDef)); + if (finalShuffleStageDef.doesSortDuringShuffle()) { + final QueryDefinitionBuilder builder = QueryDefinition.builder(); + builder.addAll(queryDef); + builder.add(StageDefinition.builder(queryDef.getNextStageNumber()) + .inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber())) + .maxWorkerCount(tuningConfig.getMaxNumWorkers()) + .signature(finalShuffleStageDef.getSignature()) + .shuffleSpec(null) + .processorFactory(new QueryResultFrameProcessorFactory()) + ); + return builder.build(); + } else { + return queryDef; } - - builder.add(StageDefinition.builder(queryDef.getNextStageNumber()) - .inputs(new StageInputSpec(queryDef.getFinalStageDefinition().getStageNumber())) - .maxWorkerCount(tuningConfig.getMaxNumWorkers()) - .signature(finalShuffleStageDef.getSignature()) - .shuffleSpec(null) - .processorFactory(new QueryResultFrameProcessorFactory()) - ); - - return builder.build(); } else { throw new ISE("Unsupported destination [%s]", querySpec.getDestination()); } } + + private static DataSchema generateDataSchema( MSQSpec querySpec, RowSignature querySignature, diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java index 971aa9b7e0c7f..d28439c0f8e0e 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/querykit/ShuffleSpecFactories.java @@ -20,6 +20,7 @@ package org.apache.druid.msq.querykit; import org.apache.druid.msq.kernel.GlobalSortMaxCountShuffleSpec; +import org.apache.druid.msq.kernel.GlobalSortTargetSizeShuffleSpec; import org.apache.druid.msq.kernel.MixShuffleSpec; /** @@ -53,4 +54,17 @@ public static ShuffleSpecFactory globalSortWithMaxPartitionCount(final int parti { return (clusterBy, aggregate) -> new GlobalSortMaxCountShuffleSpec(clusterBy, partitions, aggregate); } + + /** + * Factory that produces globally sorted partitions of a target size. 
+ */ + public static ShuffleSpecFactory getGlobalSortWithTargetSize(int targetSize) + { + return (clusterBy, aggregate) -> + new GlobalSortTargetSizeShuffleSpec( + clusterBy, + targetSize, + aggregate + ); + } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/entity/PageInformation.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/entity/PageInformation.java index 6db1f371af7ca..f50716f108cd7 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/entity/PageInformation.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/entity/PageInformation.java @@ -21,11 +21,11 @@ import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import javax.annotation.Nullable; -import java.util.Comparator; import java.util.Objects; /** @@ -39,6 +39,14 @@ public class PageInformation @Nullable private final Long sizeInBytes; + // Worker field should not flow to the users of SqlStatementResource API since users should not care about worker + @Nullable + private final Integer worker; + + // Partition field should not flow to the users of SqlStatementResource API since users should not care about partitions + @Nullable + private final Integer partition; + @JsonCreator public PageInformation( @JsonProperty("id") long id, @@ -49,8 +57,27 @@ public PageInformation( this.id = id; this.numRows = numRows; this.sizeInBytes = sizeInBytes; + this.worker = null; + this.partition = null; } + + public PageInformation( + long id, + Long numRows, + Long sizeInBytes, + Integer worker, + Integer partition + ) + { + this.id = id; + this.numRows = numRows; + this.sizeInBytes = sizeInBytes; + this.worker = worker; + this.partition = partition; + } + + @JsonProperty public long getId() { @@ -74,6 +101,20 @@ public Long getSizeInBytes() } + @Nullable + @JsonIgnore + public Integer getWorker() + { + return worker; + } + + @Nullable + @JsonIgnore + public Integer getPartition() + { + return partition; + } + @Override public boolean equals(Object o) { @@ -87,13 +128,13 @@ public boolean equals(Object o) return id == that.id && Objects.equals(numRows, that.numRows) && Objects.equals( sizeInBytes, that.sizeInBytes - ); + ) && Objects.equals(worker, that.worker) && Objects.equals(partition, that.partition); } @Override public int hashCode() { - return Objects.hash(id, numRows, sizeInBytes); + return Objects.hash(id, numRows, sizeInBytes, worker, partition); } @Override @@ -103,20 +144,8 @@ public String toString() "id=" + id + ", numRows=" + numRows + ", sizeInBytes=" + sizeInBytes + + ", worker=" + worker + + ", partition=" + partition + '}'; } - - public static Comparator getIDComparator() - { - return new PageComparator(); - } - - public static class PageComparator implements Comparator - { - @Override - public int compare(PageInformation s1, PageInformation s2) - { - return Long.compare(s1.getId(), s2.getId()); - } - } } diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java index dd4e084030064..91145985ee13f 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java +++ 
b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/sql/resources/SqlStatementResource.java @@ -504,7 +504,7 @@ private Response buildNonOkResponse(DruidException exception) } @SuppressWarnings("ReassignedVariable") - private Optional getSampleResults( + private Optional getResultSetInformation( String queryId, String dataSource, SqlStatementState sqlStatementState, @@ -617,7 +617,7 @@ private Optional getStatementStatus( taskResponse.getStatus().getCreatedTime(), signature.orElse(null), taskResponse.getStatus().getDuration(), - withResults ? getSampleResults( + withResults ? getResultSetInformation( queryId, msqControllerTask.getDataSource(), sqlStatementState, @@ -782,11 +782,16 @@ private Optional> getResultYielder( || selectedPageId.equals(pageInformation.getId())) .map(pageInformation -> { try { + if (pageInformation.getWorker() == null || pageInformation.getPartition() == null) { + throw DruidException.defensive( + "Worker or partition number is null for page id [%d]", + pageInformation.getId() + ); + } return new FrameChannelSequence(standardImplementation.openChannel( finalStage.getId(), - (int) pageInformation.getId(), - (int) pageInformation.getId() - // we would always have partition number == worker number + pageInformation.getWorker(), + pageInformation.getPartition() )); } catch (Exception e) { diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java index 613fac6203c2f..77b11a2876873 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/MultiStageQueryContext.java @@ -115,6 +115,9 @@ public class MultiStageQueryContext public static final String CTX_ROWS_PER_SEGMENT = "rowsPerSegment"; static final int DEFAULT_ROWS_PER_SEGMENT = 3000000; + public static final String CTX_ROWS_PER_PAGE = "rowsPerPage"; + static final int DEFAULT_ROWS_PER_PAGE = 100000; + public static final String CTX_ROWS_IN_MEMORY = "rowsInMemory"; // Lower than the default to minimize the impact of per-row overheads that are not accounted for by // OnheapIncrementalIndex. For example: overheads related to creating bitmaps during persist. @@ -238,6 +241,15 @@ public static int getRowsPerSegment(final QueryContext queryContext) ); } + public static int getRowsPerPage(final QueryContext queryContext) + { + return queryContext.getInt( + CTX_ROWS_PER_PAGE, + DEFAULT_ROWS_PER_PAGE + ); + } + + public static MSQSelectDestination getSelectDestination(final QueryContext queryContext) { return QueryContexts.getAsEnum( diff --git a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java index 86aed98f063e5..9481fc60541b2 100644 --- a/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java +++ b/extensions-core/multi-stage-query/src/main/java/org/apache/druid/msq/util/SqlStatementResourceHelper.java @@ -143,8 +143,13 @@ public static SqlStatementState getSqlStatementState(TaskStatusPlus taskStatusPl *
    *
* <ol>
* <li>{@link DataSourceMSQDestination} a single page is returned which adds all the counters of {@link SegmentGenerationProgressCounter.Snapshot}</li>
* <li>{@link TaskReportMSQDestination} a single page is returned which adds all the counters of {@link ChannelCounters}</li>
- * <li>{@link DurableStorageMSQDestination} a page is returned for each worker which has generated output rows. The list is sorted on page Id.
- * If the worker generated 0 rows, we do no populated a page for it. {@link PageInformation#id} is equal to the worker number</li>
+ * <li>{@link DurableStorageMSQDestination} a page is returned for each partition, worker pair that has generated output rows. The pages are populated in the following order:
+ * <ul>
+ * <li>For each partition from 0 to N</li>
+ * <li>For each worker from 0 to M</li>
+ * <li>If the number of rows for that partition, worker combination is 0, no page is created for it; pages are added in this
+ * iteration order so that the record ordering is maintained.</li>
+ * </ul></li>
* </ol>
*/ public static Optional> populatePageList( @@ -155,9 +160,9 @@ public static Optional> populatePageList( if (msqTaskReportPayload.getStages() == null || msqTaskReportPayload.getCounters() == null) { return Optional.empty(); } - int finalStage = msqTaskReportPayload.getStages().getStages().size() - 1; + MSQStagesReport.Stage finalStage = getFinalStage(msqTaskReportPayload); CounterSnapshotsTree counterSnapshotsTree = msqTaskReportPayload.getCounters(); - Map workerCounters = counterSnapshotsTree.snapshotForStage(finalStage); + Map workerCounters = counterSnapshotsTree.snapshotForStage(finalStage.getStageNumber()); if (workerCounters == null || workerCounters.isEmpty()) { return Optional.empty(); } @@ -193,27 +198,56 @@ public static Optional> populatePageList( } } else if (msqDestination instanceof DurableStorageMSQDestination) { - List pageList = new ArrayList<>(); - for (Map.Entry counterSnapshots : workerCounters.entrySet()) { - long rows = 0L; - long size = 0L; - QueryCounterSnapshot queryCounterSnapshot = counterSnapshots.getValue().getMap().getOrDefault("output", null); - if (queryCounterSnapshot != null && queryCounterSnapshot instanceof ChannelCounters.Snapshot) { - rows += Arrays.stream(((ChannelCounters.Snapshot) queryCounterSnapshot).getRows()).sum(); - size += Arrays.stream(((ChannelCounters.Snapshot) queryCounterSnapshot).getBytes()).sum(); - } - // do not populate a page if the worker generated 0 rows. - if (rows != 0L) { - pageList.add(new PageInformation(counterSnapshots.getKey(), rows, size)); - } - } - pageList.sort(PageInformation.getIDComparator()); - return Optional.of(pageList); + + return populatePagesForDurableStorageDestination(finalStage, workerCounters); } else { return Optional.empty(); } } + private static Optional> populatePagesForDurableStorageDestination( + MSQStagesReport.Stage finalStage, + Map workerCounters + ) + { + // figure out number of partitions and number of workers + int totalPartitions = finalStage.getPartitionCount(); + int totalWorkerCount = finalStage.getWorkerCount(); + + if (totalPartitions == -1) { + throw DruidException.defensive("Expected partition count to be set for stage[%d]", finalStage); + } + if (totalWorkerCount == -1) { + throw DruidException.defensive("Expected worker count to be set for stage[%d]", finalStage); + } + + + List pages = new ArrayList<>(); + for (int partitionNumber = 0; partitionNumber < totalPartitions; partitionNumber++) { + for (int workerNumber = 0; workerNumber < totalWorkerCount; workerNumber++) { + CounterSnapshots workerCounter = workerCounters.get(workerNumber); + + if (workerCounter != null && workerCounter.getMap() != null) { + QueryCounterSnapshot channelCounters = workerCounter.getMap().get("output"); + + if (channelCounters != null && channelCounters instanceof ChannelCounters.Snapshot) { + long rows = 0L; + long size = 0L; + + if (((ChannelCounters.Snapshot) channelCounters).getRows().length > partitionNumber) { + rows += ((ChannelCounters.Snapshot) channelCounters).getRows()[partitionNumber]; + size += ((ChannelCounters.Snapshot) channelCounters).getBytes()[partitionNumber]; + } + if (rows != 0L) { + pages.add(new PageInformation(pages.size(), rows, size, workerNumber, partitionNumber)); + } + } + } + } + } + return Optional.of(pages); + } + public static Optional getExceptionPayload( String queryId, TaskStatusResponse taskResponse, @@ -336,6 +370,7 @@ public static MSQStagesReport.Stage getFinalStage(MSQTaskReportPayload msqTaskRe } return null; } + public static Map getQueryExceptionDetails(Map 
payload) { return getMap(getMap(payload, "status"), "errorReport"); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java index f08311997d995..31ea643751990 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/exec/MSQSelectTest.java @@ -62,6 +62,7 @@ import org.apache.druid.query.aggregation.post.FieldAccessPostAggregator; import org.apache.druid.query.dimension.DefaultDimensionSpec; import org.apache.druid.query.expression.TestExprMacroTable; +import org.apache.druid.query.filter.LikeDimFilter; import org.apache.druid.query.groupby.GroupByQuery; import org.apache.druid.query.groupby.orderby.DefaultLimitSpec; import org.apache.druid.query.groupby.orderby.OrderByColumnSpec; @@ -111,6 +112,7 @@ public class MSQSelectTest extends MSQTestBase public static final Map QUERY_RESULTS_WITH_DURABLE_STORAGE_CONTEXT = ImmutableMap.builder() .putAll(DURABLE_STORAGE_MSQ_CONTEXT) + .put(MultiStageQueryContext.CTX_ROWS_PER_PAGE, 2) .put( MultiStageQueryContext.CTX_SELECT_DESTINATION, MSQSelectDestination.DURABLESTORAGE.getName().toLowerCase(Locale.ENGLISH) @@ -225,7 +227,9 @@ public void testSelectOnFoo() ) .setExpectedCountersForStageWorkerChannel( CounterSnapshotMatcher - .with().rows(6).frames(1), + .with() + .rows(isPageSizeLimited() ? new long[]{1, 2, 3} : new long[]{6}) + .frames(isPageSizeLimited() ? new long[]{1, 1, 1} : new long[]{1}), 0, 0, "shuffle" ) .setExpectedResultRows(ImmutableList.of( @@ -343,15 +347,16 @@ public void testSelectOnFooDuplicateColumnNames() CounterSnapshotMatcher .with().totalFiles(1), 0, 0, "input0" - ) - .setExpectedCountersForStageWorkerChannel( + ).setExpectedCountersForStageWorkerChannel( CounterSnapshotMatcher .with().rows(6).frames(1), 0, 0, "output" ) .setExpectedCountersForStageWorkerChannel( CounterSnapshotMatcher - .with().rows(6).frames(1), + .with() + .rows(isPageSizeLimited() ? new long[]{1, 2, 3} : new long[]{6}) + .frames(isPageSizeLimited() ? 
new long[]{1, 1, 1} : new long[]{1}), 0, 0, "shuffle" ) .setExpectedResultRows(ImmutableList.of( @@ -1253,7 +1258,7 @@ public void testGroupByOrderByAggregationWithLimitAndOffset() } @Test - public void testExternSelect1() throws IOException + public void testExternGroupBy() throws IOException { final File toRead = MSQTestFileUtils.getResourceAsTemporaryFile(temporaryFolder, this, "/wikipedia-sampled.json"); final String toReadAsJson = queryFramework().queryJsonMapper().writeValueAsString(toRead.getAbsolutePath()); @@ -1339,6 +1344,151 @@ public void testExternSelect1() throws IOException .verifyResults(); } + + @Test + public void testExternSelectWithMultipleWorkers() throws IOException + { + Map multipleWorkerContext = new HashMap<>(context); + multipleWorkerContext.put(MultiStageQueryContext.CTX_MAX_NUM_TASKS, 3); + + final File toRead = MSQTestFileUtils.getResourceAsTemporaryFile(temporaryFolder, this, "/wikipedia-sampled.json"); + final String toReadAsJson = queryFramework().queryJsonMapper().writeValueAsString(toRead.getAbsolutePath()); + + RowSignature rowSignature = RowSignature.builder() + .add("__time", ColumnType.LONG) + .add("user", ColumnType.STRING) + .build(); + + final ScanQuery expectedQuery = + newScanQueryBuilder().dataSource( + new ExternalDataSource( + new LocalInputSource(null, null, ImmutableList.of(toRead.getAbsoluteFile(), toRead.getAbsoluteFile())), + new JsonInputFormat(null, null, null, null, null), + RowSignature.builder() + .add("timestamp", ColumnType.STRING) + .add("page", ColumnType.STRING) + .add("user", ColumnType.STRING) + .build() + ) + ).eternityInterval().virtualColumns( + new ExpressionVirtualColumn( + "v0", + "timestamp_floor(timestamp_parse(\"timestamp\",null,'UTC'),'P1D',null,'UTC')", + ColumnType.LONG, + CalciteTests.createExprMacroTable() + ) + ).columns("user", "v0").filters(new LikeDimFilter("user", "%ot%", null, null)) + .context(defaultScanQueryContext(multipleWorkerContext, RowSignature.builder() + .add( + "user", + ColumnType.STRING + ) + .add( + "v0", + ColumnType.LONG + ) + .build())) + .build(); + + SelectTester selectTester = testSelectQuery() + .setSql("SELECT\n" + + " floor(TIME_PARSE(\"timestamp\") to day) AS __time,\n" + + " user\n" + + "FROM TABLE(\n" + + " EXTERN(\n" + + " '{ \"files\": [" + toReadAsJson + "," + toReadAsJson + "],\"type\":\"local\"}',\n" + + " '{\"type\": \"json\"}',\n" + + " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"page\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}]'\n" + + " )\n" + + ") where user like '%ot%'") + .setExpectedRowSignature(rowSignature) + .setQueryContext(multipleWorkerContext) + .setExpectedResultRows(ImmutableList.of( + new Object[]{1466985600000L, "Lsjbot"}, + new Object[]{1466985600000L, "Lsjbot"}, + new Object[]{1466985600000L, "Beau.bot"}, + new Object[]{1466985600000L, "Beau.bot"}, + new Object[]{1466985600000L, "Lsjbot"}, + new Object[]{1466985600000L, "Lsjbot"}, + new Object[]{1466985600000L, "TaxonBot"}, + new Object[]{1466985600000L, "TaxonBot"}, + new Object[]{1466985600000L, "GiftBot"}, + new Object[]{1466985600000L, "GiftBot"} + )) + .setExpectedMSQSpec( + MSQSpec + .builder() + .query(expectedQuery) + .columnMappings(new ColumnMappings( + ImmutableList.of( + new ColumnMapping("v0", "__time"), + new ColumnMapping("user", "user") + ) + )) + .tuningConfig(MSQTuningConfig.defaultConfig()) + .destination(isDurableStorageDestination() + ? 
DurableStorageMSQDestination.INSTANCE + : TaskReportMSQDestination.INSTANCE) + .build() + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(20).bytes(toRead.length()).files(1).totalFiles(1), + 0, 0, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(5).frames(1), + 0, 0, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with() + .rows(isPageSizeLimited() ? new long[]{1L, 1L, 1L, 2L} : new long[]{5L}) + .frames(isPageSizeLimited() ? new long[]{1L, 1L, 1L, 1L} : new long[]{1L}), + 0, 0, "shuffle" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(20).bytes(toRead.length()).files(1).totalFiles(1), + 0, 1, "input0" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(5).frames(1), + 0, 1, "output" + ) + .setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with() + .rows(isPageSizeLimited() ? new long[]{1L, 1L, 1L, 2L} : new long[]{5L}) + .frames(isPageSizeLimited() ? new long[]{1L, 1L, 1L, 1L} : new long[]{1L}), + 0, 1, "shuffle" + ); + // adding result stage counter checks + if (isPageSizeLimited()) { + selectTester = selectTester.setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(2, 0, 2).frames(1, 0, 1), + 1, 0, "input0" + ).setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(2, 0, 2).frames(1, 0, 1), + 1, 0, "output" + ); + selectTester = selectTester.setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(0, 2, 0, 4).frames(0, 1, 0, 1), + 1, 1, "input0" + ).setExpectedCountersForStageWorkerChannel( + CounterSnapshotMatcher + .with().rows(0, 2, 0, 4).frames(0, 1, 0, 1), + 1, 1, "output" + ); + } + selectTester.verifyResults(); + } + @Test public void testIncorrectSelectQuery() { @@ -2434,4 +2584,9 @@ public boolean isDurableStorageDestination() { return QUERY_RESULTS_WITH_DURABLE_STORAGE.equals(contextName) || QUERY_RESULTS_WITH_DEFAULT_CONTEXT.equals(context); } + + public boolean isPageSizeLimited() + { + return QUERY_RESULTS_WITH_DURABLE_STORAGE.equals(contextName); + } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/ResultSetInformationTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/ResultSetInformationTest.java index ce84ac91fd405..0d3ca30b0f5a1 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/ResultSetInformationTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/ResultSetInformationTest.java @@ -66,7 +66,7 @@ public void sanityTest() throws JsonProcessingException MAPPER.readValue(MAPPER.writeValueAsString(RESULTS), ResultSetInformation.class).hashCode() ); Assert.assertEquals( - "ResultSetInformation{numTotalRows=1, totalSizeInBytes=1, resultFormat=object, records=null, dataSource='ds', pages=[PageInformation{id=0, numRows=null, sizeInBytes=1}]}", + "ResultSetInformation{numTotalRows=1, totalSizeInBytes=1, resultFormat=object, records=null, dataSource='ds', pages=[PageInformation{id=0, numRows=null, sizeInBytes=1, worker=null, partition=null}]}", RESULTS.toString() ); } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/SqlStatementResultTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/SqlStatementResultTest.java index 
0434c89ce193c..03c017b7442d9 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/SqlStatementResultTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/entity/SqlStatementResultTest.java @@ -86,7 +86,7 @@ public void sanityTest() throws JsonProcessingException + " createdAt=2023-05-31T12:00:00.000Z," + " sqlRowSignature=[ColumnNameAndTypes{colName='_time', sqlTypeName='TIMESTAMP', nativeTypeName='LONG'}, ColumnNameAndTypes{colName='alias', sqlTypeName='VARCHAR', nativeTypeName='STRING'}, ColumnNameAndTypes{colName='market', sqlTypeName='VARCHAR', nativeTypeName='STRING'}]," + " durationInMs=100," - + " resultSetInformation=ResultSetInformation{numTotalRows=1, totalSizeInBytes=1, resultFormat=object, records=null, dataSource='ds', pages=[PageInformation{id=0, numRows=null, sizeInBytes=1}]}," + + " resultSetInformation=ResultSetInformation{numTotalRows=1, totalSizeInBytes=1, resultFormat=object, records=null, dataSource='ds', pages=[PageInformation{id=0, numRows=null, sizeInBytes=1, worker=null, partition=null}]}," + " errorResponse={error=druidException, errorCode=QueryNotSupported, persona=USER, category=UNCATEGORIZED, errorMessage=QueryNotSupported, context={}}}", SQL_STATEMENT_RESULT.toString() ); diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java index 415e36a02d498..6650c77855552 100644 --- a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java +++ b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/sql/resources/SqlMSQStatementResourcePostTest.java @@ -39,6 +39,7 @@ import org.apache.druid.msq.sql.entity.ResultSetInformation; import org.apache.druid.msq.sql.entity.SqlStatementResult; import org.apache.druid.msq.test.MSQTestBase; +import org.apache.druid.msq.test.MSQTestFileUtils; import org.apache.druid.msq.test.MSQTestOverlordServiceClient; import org.apache.druid.msq.util.MultiStageQueryContext; import org.apache.druid.query.ExecutionMode; @@ -55,6 +56,7 @@ import javax.ws.rs.core.Response; import javax.ws.rs.core.StreamingOutput; import java.io.ByteArrayOutputStream; +import java.io.File; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -307,6 +309,7 @@ public void testWithDurableStorage() throws IOException { Map context = defaultAsyncContext(); context.put(MultiStageQueryContext.CTX_SELECT_DESTINATION, MSQSelectDestination.DURABLESTORAGE.getName()); + context.put(MultiStageQueryContext.CTX_ROWS_PER_PAGE, 2); SqlStatementResult sqlStatementResult = (SqlStatementResult) resource.doPost( new SqlQuery( @@ -321,6 +324,12 @@ public void testWithDurableStorage() throws IOException SqlStatementResourceTest.makeOkRequest() ).getEntity(); + Assert.assertEquals(ImmutableList.of( + new PageInformation(0, 1L, 75L, 0, 0), + new PageInformation(1, 2L, 121L, 0, 1), + new PageInformation(2, 3L, 164L, 0, 2) + ), sqlStatementResult.getResultSetInformation().getPages()); + assertExpectedResults( "{\"cnt\":1,\"dim1\":\"\"}\n" + "{\"cnt\":1,\"dim1\":\"10.1\"}\n" @@ -335,23 +344,125 @@ public void testWithDurableStorage() throws IOException ResultFormat.OBJECTLINES.name(), SqlStatementResourceTest.makeOkRequest() ), - objectMapper); + objectMapper + ); assertExpectedResults( - 
"{\"cnt\":1,\"dim1\":\"\"}\n" - + "{\"cnt\":1,\"dim1\":\"10.1\"}\n" - + "{\"cnt\":1,\"dim1\":\"2\"}\n" - + "{\"cnt\":1,\"dim1\":\"1\"}\n" + "{\"cnt\":1,\"dim1\":\"\"}\n\n", + resource.doGetResults( + sqlStatementResult.getQueryId(), + 0L, + ResultFormat.OBJECTLINES.name(), + SqlStatementResourceTest.makeOkRequest() + ), + objectMapper + ); + + assertExpectedResults( + "{\"cnt\":1,\"dim1\":\"1\"}\n" + "{\"cnt\":1,\"dim1\":\"def\"}\n" + "{\"cnt\":1,\"dim1\":\"abc\"}\n" + "\n", resource.doGetResults( sqlStatementResult.getQueryId(), - 0L, + 2L, ResultFormat.OBJECTLINES.name(), SqlStatementResourceTest.makeOkRequest() ), - objectMapper); + objectMapper + ); + } + + + @Test + public void testMultipleWorkersWithPageSizeLimiting() throws IOException + { + Map context = defaultAsyncContext(); + context.put(MultiStageQueryContext.CTX_SELECT_DESTINATION, MSQSelectDestination.DURABLESTORAGE.getName()); + context.put(MultiStageQueryContext.CTX_ROWS_PER_PAGE, 2); + context.put(MultiStageQueryContext.CTX_MAX_NUM_TASKS, 3); + + final File toRead = MSQTestFileUtils.getResourceAsTemporaryFile(temporaryFolder, this, "/wikipedia-sampled.json"); + final String toReadAsJson = queryFramework().queryJsonMapper().writeValueAsString(toRead.getAbsolutePath()); + + + SqlStatementResult sqlStatementResult = (SqlStatementResult) resource.doPost( + new SqlQuery( + "SELECT\n" + + " floor(TIME_PARSE(\"timestamp\") to day) AS __time,\n" + + " user\n" + + "FROM TABLE(\n" + + " EXTERN(\n" + + " '{ \"files\": [" + toReadAsJson + "," + toReadAsJson + "],\"type\":\"local\"}',\n" + + " '{\"type\": \"json\"}',\n" + + " '[{\"name\": \"timestamp\", \"type\": \"string\"}, {\"name\": \"page\", \"type\": \"string\"}, {\"name\": \"user\", \"type\": \"string\"}]'\n" + + " )\n" + + ") where user like '%ot%'", + null, + false, + false, + false, + context, + null + ), + SqlStatementResourceTest.makeOkRequest() + ).getEntity(); + + Assert.assertEquals(ImmutableList.of( + new PageInformation(0, 2L, 128L, 0, 0), + new PageInformation(1, 2L, 132L, 1, 1), + new PageInformation(2, 2L, 128L, 0, 2), + new PageInformation(3, 4L, 228L, 1, 3) + ), sqlStatementResult.getResultSetInformation().getPages()); + + + List> rows = new ArrayList<>(); + rows.add(ImmutableList.of(1466985600000L, "Lsjbot")); + rows.add(ImmutableList.of(1466985600000L, "Lsjbot")); + rows.add(ImmutableList.of(1466985600000L, "Beau.bot")); + rows.add(ImmutableList.of(1466985600000L, "Beau.bot")); + rows.add(ImmutableList.of(1466985600000L, "Lsjbot")); + rows.add(ImmutableList.of(1466985600000L, "Lsjbot")); + rows.add(ImmutableList.of(1466985600000L, "TaxonBot")); + rows.add(ImmutableList.of(1466985600000L, "TaxonBot")); + rows.add(ImmutableList.of(1466985600000L, "GiftBot")); + rows.add(ImmutableList.of(1466985600000L, "GiftBot")); + + + Assert.assertEquals(rows, SqlStatementResourceTest.getResultRowsFromResponse(resource.doGetResults( + sqlStatementResult.getQueryId(), + null, + ResultFormat.ARRAY.name(), + SqlStatementResourceTest.makeOkRequest() + ))); + + Assert.assertEquals(rows.subList(0, 2), SqlStatementResourceTest.getResultRowsFromResponse(resource.doGetResults( + sqlStatementResult.getQueryId(), + 0L, + ResultFormat.ARRAY.name(), + SqlStatementResourceTest.makeOkRequest() + ))); + + Assert.assertEquals(rows.subList(2, 4), SqlStatementResourceTest.getResultRowsFromResponse(resource.doGetResults( + sqlStatementResult.getQueryId(), + 1L, + ResultFormat.ARRAY.name(), + SqlStatementResourceTest.makeOkRequest() + ))); + + Assert.assertEquals(rows.subList(4, 6), 
SqlStatementResourceTest.getResultRowsFromResponse(resource.doGetResults( + sqlStatementResult.getQueryId(), + 2L, + ResultFormat.ARRAY.name(), + SqlStatementResourceTest.makeOkRequest() + ))); + + Assert.assertEquals(rows.subList(6, 10), SqlStatementResourceTest.getResultRowsFromResponse(resource.doGetResults( + sqlStatementResult.getQueryId(), + 3L, + ResultFormat.ARRAY.name(), + SqlStatementResourceTest.makeOkRequest() + ))); } @Test @@ -457,7 +568,12 @@ public void testResultFormatWithParamInSelect() throws IOException ))); } - private byte[] createExpectedResultsInFormat(ResultFormat resultFormat, List resultsList, List rowSignature, ObjectMapper jsonMapper) throws Exception + private byte[] createExpectedResultsInFormat( + ResultFormat resultFormat, + List resultsList, + List rowSignature, + ObjectMapper jsonMapper + ) throws Exception { ByteArrayOutputStream os = new ByteArrayOutputStream(); try (final ResultFormat.Writer writer = resultFormat.createFormatter(os, jsonMapper)) { @@ -466,7 +582,8 @@ private byte[] createExpectedResultsInFormat(ResultFormat resultFormat, List, List>> } MSQTaskReportPayload payload = getPayloadOrThrow(controllerId); - verifyCounters(payload.getCounters()); - verifyWorkerCount(payload.getCounters()); - if (payload.getStatus().getErrorReport() != null) { throw new ISE("Query %s failed due to %s", sql, payload.getStatus().getErrorReport().toString()); @@ -1365,24 +1364,36 @@ public Pair, List>> } else { StageDefinition finalStage = Objects.requireNonNull(SqlStatementResourceHelper.getFinalStage( payload)).getStageDefinition(); - Closer closer = Closer.create(); - InputChannelFactory inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( - controllerId, - localFileStorageConnector, - closer, - true + + Optional> pages = SqlStatementResourceHelper.populatePageList( + payload, + spec.getDestination() ); - rows = new FrameChannelSequence(inputChannelFactory.openChannel( - finalStage.getId(), - 0, - 0 - )).flatMap(frame -> SqlStatementResourceHelper.getResultSequence( - msqControllerTask, - finalStage, - frame, - objectMapper - )).withBaggage(closer).toList(); + if (!pages.isPresent()) { + throw new ISE("No query results found"); + } + + rows = new ArrayList<>(); + for (PageInformation pageInformation : pages.get()) { + Closer closer = Closer.create(); + InputChannelFactory inputChannelFactory = DurableStorageInputChannelFactory.createStandardImplementation( + controllerId, + localFileStorageConnector, + closer, + true + ); + rows.addAll(new FrameChannelSequence(inputChannelFactory.openChannel( + finalStage.getId(), + pageInformation.getWorker(), + pageInformation.getPartition() + )).flatMap(frame -> SqlStatementResourceHelper.getResultSequence( + msqControllerTask, + finalStage, + frame, + objectMapper + )).withBaggage(closer).toList()); + } } if (rows == null) { throw new ISE("Query successful but no results found"); @@ -1395,6 +1406,10 @@ public Pair, List>> } log.info("Found spec: %s", objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(spec)); + + verifyCounters(payload.getCounters()); + verifyWorkerCount(payload.getCounters()); + return new Pair<>(spec, Pair.of(payload.getResults().getSignature(), rows)); } } diff --git a/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java new file mode 100644 index 0000000000000..806bd8ebe9888 --- /dev/null +++ 
b/extensions-core/multi-stage-query/src/test/java/org/apache/druid/msq/util/SqlStatementResourceHelperTest.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.druid.msq.util; + +import com.google.common.collect.ImmutableMap; +import org.apache.druid.frame.Frame; +import org.apache.druid.indexer.TaskState; +import org.apache.druid.java.util.common.Pair; +import org.apache.druid.java.util.common.logger.Logger; +import org.apache.druid.msq.counters.ChannelCounters; +import org.apache.druid.msq.counters.CounterSnapshots; +import org.apache.druid.msq.counters.CounterSnapshotsTree; +import org.apache.druid.msq.indexing.destination.DurableStorageMSQDestination; +import org.apache.druid.msq.indexing.report.MSQStagesReport; +import org.apache.druid.msq.indexing.report.MSQStatusReport; +import org.apache.druid.msq.indexing.report.MSQTaskReportPayload; +import org.apache.druid.msq.indexing.report.MSQTaskReportTest; +import org.apache.druid.msq.sql.entity.PageInformation; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Test; + +import java.util.ArrayDeque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; + +public class SqlStatementResourceHelperTest +{ + + private static final Logger log = new Logger(SqlStatementResourceHelperTest.class); + + @Test + public void testDistinctPartitionsOnEachWorker() + { + CounterSnapshotsTree counterSnapshots = new CounterSnapshotsTree(); + ChannelCounters worker0 = createChannelCounters(new int[]{0, 3, 6}); + ChannelCounters worker1 = createChannelCounters(new int[]{1, 4, 4, 7, 9, 10, 13}); + ChannelCounters worker2 = createChannelCounters(new int[]{2, 5, 8, 11, 14}); + + counterSnapshots.put(0, 0, new CounterSnapshots(ImmutableMap.of("output", worker0.snapshot()))); + counterSnapshots.put(0, 1, new CounterSnapshots(ImmutableMap.of("output", worker1.snapshot()))); + counterSnapshots.put(0, 2, new CounterSnapshots(ImmutableMap.of("output", worker2.snapshot()))); + + MSQTaskReportPayload payload = new MSQTaskReportPayload(new MSQStatusReport( + TaskState.SUCCESS, + null, + new ArrayDeque<>(), + null, + 0, + 1, + 2, + null + ), MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(0, 3), + ImmutableMap.of(0, 15) + ), counterSnapshots, null); + + Optional> pages = SqlStatementResourceHelper.populatePageList( + payload, + DurableStorageMSQDestination.instance() + ); + validatePages(pages.get(), createValidationMap(worker0, worker1, worker2)); + } + + @Test + public void testOnePartitionOnEachWorker() + { + CounterSnapshotsTree counterSnapshots = new CounterSnapshotsTree(); + ChannelCounters worker0 = 
createChannelCounters(new int[]{0}); + ChannelCounters worker1 = createChannelCounters(new int[]{1}); + ChannelCounters worker2 = createChannelCounters(new int[]{2}); + ChannelCounters worker3 = createChannelCounters(new int[]{4}); + + counterSnapshots.put(0, 0, new CounterSnapshots(ImmutableMap.of("output", worker0.snapshot()))); + counterSnapshots.put(0, 1, new CounterSnapshots(ImmutableMap.of("output", worker1.snapshot()))); + counterSnapshots.put(0, 2, new CounterSnapshots(ImmutableMap.of("output", worker2.snapshot()))); + counterSnapshots.put(0, 3, new CounterSnapshots(ImmutableMap.of("output", worker3.snapshot()))); + + MSQTaskReportPayload payload = new MSQTaskReportPayload(new MSQStatusReport( + TaskState.SUCCESS, + null, + new ArrayDeque<>(), + null, + 0, + 1, + 2, + null + ), MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(0, 4), + ImmutableMap.of(0, 4) + ), counterSnapshots, null); + + Optional> pages = SqlStatementResourceHelper.populatePageList( + payload, + DurableStorageMSQDestination.instance() + ); + validatePages(pages.get(), createValidationMap(worker0, worker1, worker2)); + } + + + @Test + public void testCommonPartitionsOnEachWorker() + { + CounterSnapshotsTree counterSnapshots = new CounterSnapshotsTree(); + ChannelCounters worker0 = createChannelCounters(new int[]{0, 1, 2, 3, 8, 9}); + ChannelCounters worker1 = createChannelCounters(new int[]{1, 4, 12}); + ChannelCounters worker2 = createChannelCounters(new int[]{20}); + ChannelCounters worker3 = createChannelCounters(new int[]{2, 2, 5, 6, 7, 9, 15}); + + counterSnapshots.put(0, 0, new CounterSnapshots(ImmutableMap.of("output", worker0.snapshot()))); + counterSnapshots.put(0, 1, new CounterSnapshots(ImmutableMap.of("output", worker1.snapshot()))); + counterSnapshots.put(0, 2, new CounterSnapshots(ImmutableMap.of("output", worker2.snapshot()))); + counterSnapshots.put(0, 3, new CounterSnapshots(ImmutableMap.of("output", worker3.snapshot()))); + + MSQTaskReportPayload payload = new MSQTaskReportPayload(new MSQStatusReport( + TaskState.SUCCESS, + null, + new ArrayDeque<>(), + null, + 0, + 1, + 2, + null + ), MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(0, 4), + ImmutableMap.of(0, 21) + ), counterSnapshots, null); + + Optional> pages = + SqlStatementResourceHelper.populatePageList(payload, DurableStorageMSQDestination.instance()); + validatePages(pages.get(), createValidationMap(worker0, worker1, worker2, worker3)); + } + + + @Test + public void testNullChannelCounters() + { + CounterSnapshotsTree counterSnapshots = new CounterSnapshotsTree(); + ChannelCounters worker0 = createChannelCounters(new int[0]); + ChannelCounters worker1 = createChannelCounters(new int[]{1, 4, 12}); + ChannelCounters worker2 = createChannelCounters(new int[]{20}); + ChannelCounters worker3 = createChannelCounters(new int[]{2, 2, 5, 6, 7, 9, 15}); + + counterSnapshots.put(0, 0, new CounterSnapshots(new HashMap<>())); + counterSnapshots.put(0, 1, new CounterSnapshots(ImmutableMap.of("output", worker1.snapshot()))); + counterSnapshots.put(0, 2, new CounterSnapshots(ImmutableMap.of("output", worker2.snapshot()))); + counterSnapshots.put(0, 3, new CounterSnapshots(ImmutableMap.of("output", worker3.snapshot()))); + + MSQTaskReportPayload payload = new MSQTaskReportPayload(new MSQStatusReport( + TaskState.SUCCESS, + null, + new ArrayDeque<>(), + null, + 0, + 1, + 2, + null + ), MSQStagesReport.create( + 
MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(0, 4), + ImmutableMap.of(0, 21) + ), counterSnapshots, null); + + Optional> pages = SqlStatementResourceHelper.populatePageList( + payload, + DurableStorageMSQDestination.instance() + ); + validatePages(pages.get(), createValidationMap(worker0, worker1, worker2, worker3)); + } + + + @Test + public void testConsecutivePartitionsOnEachWorker() + { + CounterSnapshotsTree counterSnapshots = new CounterSnapshotsTree(); + ChannelCounters worker0 = createChannelCounters(new int[]{0, 1, 2}); + ChannelCounters worker1 = createChannelCounters(new int[]{3, 4, 5}); + ChannelCounters worker2 = createChannelCounters(new int[]{6, 7, 8}); + ChannelCounters worker3 = createChannelCounters(new int[]{9, 10, 11, 12}); + + counterSnapshots.put(0, 0, new CounterSnapshots(ImmutableMap.of("output", worker0.snapshot()))); + counterSnapshots.put(0, 1, new CounterSnapshots(ImmutableMap.of("output", worker1.snapshot()))); + counterSnapshots.put(0, 2, new CounterSnapshots(ImmutableMap.of("output", worker2.snapshot()))); + counterSnapshots.put(0, 3, new CounterSnapshots(ImmutableMap.of("output", worker3.snapshot()))); + + MSQTaskReportPayload payload = new MSQTaskReportPayload(new MSQStatusReport( + TaskState.SUCCESS, + null, + new ArrayDeque<>(), + null, + 0, + 1, + 2, + null + ), MSQStagesReport.create( + MSQTaskReportTest.QUERY_DEFINITION, + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableMap.of(0, 4), + ImmutableMap.of(0, 13) + ), counterSnapshots, null); + + Optional> pages = SqlStatementResourceHelper.populatePageList( + payload, + DurableStorageMSQDestination.instance() + ); + validatePages(pages.get(), createValidationMap(worker0, worker1, worker2, worker3)); + } + + + private void validatePages( + List pageList, + Map>> partitionToWorkerToRowsBytes + ) + { + int currentPage = 0; + for (Map.Entry>> partitionWorker : partitionToWorkerToRowsBytes.entrySet()) { + for (Map.Entry> workerRowsBytes : partitionWorker.getValue().entrySet()) { + PageInformation pageInformation = pageList.get(currentPage); + Assert.assertEquals(currentPage, pageInformation.getId()); + Assert.assertEquals(workerRowsBytes.getValue().lhs, pageInformation.getNumRows()); + Assert.assertEquals(workerRowsBytes.getValue().rhs, pageInformation.getSizeInBytes()); + Assert.assertEquals(partitionWorker.getKey(), pageInformation.getPartition()); + Assert.assertEquals(workerRowsBytes.getKey(), pageInformation.getWorker()); + log.debug(pageInformation.toString()); + currentPage++; + } + } + Assert.assertEquals(currentPage, pageList.size()); + } + + private Map>> createValidationMap( + ChannelCounters... 
workers + ) + { + if (workers == null || workers.length == 0) { + return new HashMap<>(); + } else { + Map>> partitionToWorkerToRowsBytes = new TreeMap<>(); + for (int worker = 0; worker < workers.length; worker++) { + ChannelCounters.Snapshot workerCounter = workers[worker].snapshot(); + for (int partition = 0; workerCounter != null && partition < workerCounter.getRows().length; partition++) { + Map> workerMap = partitionToWorkerToRowsBytes.computeIfAbsent( + partition, + k -> new TreeMap<>() + ); + + if (workerCounter.getRows()[partition] != 0) { + workerMap.put( + worker, + new Pair<>( + workerCounter.getRows()[partition], + workerCounter.getBytes()[partition] + ) + ); + } + + } + } + return partitionToWorkerToRowsBytes; + } + } + + + private ChannelCounters createChannelCounters(int[] partitions) + { + if (partitions == null || partitions.length == 0) { + return new ChannelCounters(); + } + ChannelCounters channelCounters = new ChannelCounters(); + int prev = -1; + for (int current : partitions) { + if (prev > current) { + throw new IllegalArgumentException("Channel numbers should be in increasing order"); + } + channelCounters.addFrame(current, createFrame(current * 10 + 1, 100L)); + prev = current; + } + return channelCounters; + } + + + private Frame createFrame(int numRows, long numBytes) + { + Frame frame = EasyMock.mock(Frame.class); + EasyMock.expect(frame.numRows()).andReturn(numRows).anyTimes(); + EasyMock.expect(frame.numBytes()).andReturn(numBytes).anyTimes(); + EasyMock.replay(frame); + return frame; + } +}
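
For illustration only, here is a minimal client-side sketch of how the new rowsPerPage context parameter is intended to be used together with the SQL statements API once this patch is applied. It is not part of the patch; the Router address (http://localhost:8888), the datasource foo, and the placeholder query id are assumptions, and the page numbers passed to the results endpoint correspond to the pages entries produced by populatePageList.

// Hypothetical usage sketch (Java 11+): submit an async SELECT whose results are written to
// durable storage in pages of roughly 2 rows, then fetch one page of results.
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class RowsPerPageExample
{
  public static void main(String[] args) throws Exception
  {
    HttpClient client = HttpClient.newHttpClient();

    // Assumed Router endpoint; adjust host and port for your deployment.
    String baseUrl = "http://localhost:8888/druid/v2/sql/statements";

    // rowsPerPage only takes effect when selectDestination is durableStorage.
    String body =
        "{\n"
        + "  \"query\": \"SELECT cnt, dim1 FROM foo\",\n"
        + "  \"context\": {\n"
        + "    \"executionMode\": \"ASYNC\",\n"
        + "    \"selectDestination\": \"durableStorage\",\n"
        + "    \"rowsPerPage\": 2\n"
        + "  }\n"
        + "}";

    HttpRequest submit = HttpRequest.newBuilder()
        .uri(URI.create(baseUrl))
        .header("Content-Type", "application/json")
        .POST(HttpRequest.BodyPublishers.ofString(body))
        .build();

    // The response contains the queryId; once the query succeeds, resultSetInformation.pages
    // lists one entry per page (the worker and partition numbers stay internal to the server).
    HttpResponse<String> submitted = client.send(submit, HttpResponse.BodyHandlers.ofString());
    System.out.println(submitted.body());

    // Fetch a single page of results. Replace the placeholder with the queryId returned above;
    // valid page numbers are the "id" values from the pages array.
    String queryId = "query-id-from-the-response-above";
    HttpRequest firstPage = HttpRequest.newBuilder()
        .uri(URI.create(baseUrl + "/" + queryId + "/results?page=0"))
        .GET()
        .build();
    System.out.println(client.send(firstPage, HttpResponse.BodyHandlers.ofString()).body());
  }
}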