Add MSQ query context maxNumSegments (#16637)
* Add MSQ query context maxNumSegments.

- Default is MAX_INT (unbounded).
- When set, if a time chunk contains more segments than the limit specified in
  the query context, the MSQ task fails with a TooManySegments fault.

* Fixup hashCode().

* Rename and checkpoint.

* Add some insert and replace happy and sad path tests.

* Update error msg.

* Commentary

* Adjust the default to be null (meaning no max bound on number of segments).

Also fix formatter.

* Fix CodeQL warnings and minor cleanup.

* Assert on maxNumSegments tuning config.

* Minor test cleanup.

* Use null default for the MultiStageQueryContext as well

* Review feedback

* Review feedback

* Move logic to common function getPartitionsByBucket shared by INSERT and REPLACE.

* Rename to validateNumSegmentsPerBucketOrThrow() for consistency.

* Add segmentGranularity to error message.
abhishekrb19 authored Jun 26, 2024
1 parent b772277 commit 82117e8
Showing 12 changed files with 400 additions and 11 deletions.
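
As a quick orientation before the diff: the feature described in the commit message above is driven by the maxNumSegments query context key. Below is a minimal, hypothetical sketch (not part of the change; class and variable names are illustrative only) of the kind of context map a client might attach to an MSQ INSERT or REPLACE query.

import java.util.Map;

public class MaxNumSegmentsContextSketch
{
  public static void main(String[] args)
  {
    // Hypothetical MSQ query context: cap every time chunk at 100 generated segments.
    // If any time chunk would exceed the cap, the task fails with a
    // TooManySegmentsInTimeChunk fault instead of silently creating many small segments.
    final Map<String, Object> queryContext = Map.of("maxNumSegments", 100);
    System.out.println("MSQ query context: " + queryContext);
  }
}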
@@ -119,6 +119,7 @@
import org.apache.druid.msq.indexing.error.MSQWarningReportLimiterPublisher;
import org.apache.druid.msq.indexing.error.QueryNotSupportedFault;
import org.apache.druid.msq.indexing.error.TooManyBucketsFault;
import org.apache.druid.msq.indexing.error.TooManySegmentsInTimeChunkFault;
import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
import org.apache.druid.msq.indexing.error.UnknownFault;
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
@@ -962,6 +963,14 @@ private List<SegmentIdWithShardSpec> generateSegmentIdsWithShardSpecsForAppend(

final Granularity segmentGranularity = destination.getSegmentGranularity();

// Compute and validate partitions by bucket (time chunk) if a maximum number of segments per time chunk is enforced
if (querySpec.getTuningConfig().getMaxNumSegments() != null) {
final Map<DateTime, List<Pair<Integer, ClusterByPartition>>> partitionsByBucket =
getPartitionsByBucket(partitionBoundaries, segmentGranularity, keyReader);

validateNumSegmentsPerBucketOrThrow(partitionsByBucket, segmentGranularity);
}

String previousSegmentId = null;

segmentReport = new MSQSegmentReport(
@@ -1029,6 +1038,43 @@ private List<SegmentIdWithShardSpec> generateSegmentIdsWithShardSpecsForAppend(
return retVal;
}

/**
* Return partition ranges by bucket (time chunk).
*/
private Map<DateTime, List<Pair<Integer, ClusterByPartition>>> getPartitionsByBucket(
final ClusterByPartitions partitionBoundaries,
final Granularity segmentGranularity,
final RowKeyReader keyReader
)
{
final Map<DateTime, List<Pair<Integer, ClusterByPartition>>> partitionsByBucket = new HashMap<>();
for (int i = 0; i < partitionBoundaries.ranges().size(); i++) {
final ClusterByPartition partitionBoundary = partitionBoundaries.ranges().get(i);
final DateTime bucketDateTime = getBucketDateTime(partitionBoundary, segmentGranularity, keyReader);
partitionsByBucket.computeIfAbsent(bucketDateTime, ignored -> new ArrayList<>())
.add(Pair.of(i, partitionBoundary));
}
return partitionsByBucket;
}
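
The helper getBucketDateTime() used above already exists in the controller and is not shown in this diff. As a rough, standalone sketch of the bucketing idea it relies on (names here are hypothetical, and it assumes the partition's start key leads with the row timestamp):

import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.joda.time.DateTime;

public class BucketDateTimeSketch
{
  // Maps a row timestamp (millis since epoch) to the start of its time chunk under the
  // given segment granularity; ALL granularity collapses everything into a single bucket.
  static DateTime bucketStart(final long timestampMillis, final Granularity segmentGranularity)
  {
    if (Granularities.ALL.equals(segmentGranularity)) {
      return DateTimes.utc(0);
    }
    return segmentGranularity.bucketStart(DateTimes.utc(timestampMillis));
  }

  public static void main(String[] args)
  {
    // 2024-06-26T13:45:00Z falls into the DAY bucket starting at 2024-06-26T00:00:00Z.
    System.out.println(bucketStart(DateTimes.of("2024-06-26T13:45:00Z").getMillis(), Granularities.DAY));
  }
}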

private void validateNumSegmentsPerBucketOrThrow(
final Map<DateTime, List<Pair<Integer, ClusterByPartition>>> partitionsByBucket,
final Granularity segmentGranularity
)
{
final Integer maxNumSegments = querySpec.getTuningConfig().getMaxNumSegments();
if (maxNumSegments == null) {
// Return early because a null value indicates no maximum, i.e., a time chunk can have any number of segments.
return;
}
for (final Map.Entry<DateTime, List<Pair<Integer, ClusterByPartition>>> bucketEntry : partitionsByBucket.entrySet()) {
final int numSegmentsInTimeChunk = bucketEntry.getValue().size();
if (numSegmentsInTimeChunk > maxNumSegments) {
throw new MSQException(new TooManySegmentsInTimeChunkFault(bucketEntry.getKey(), numSegmentsInTimeChunk, maxNumSegments, segmentGranularity));
}
}
}

/**
* Used by {@link #generateSegmentIdsWithShardSpecs}.
*
@@ -1072,13 +1118,11 @@ private List<SegmentIdWithShardSpec> generateSegmentIdsWithShardSpecsForReplace(
}

// Group partition ranges by bucket (time chunk), so we can generate shardSpecs for each bucket independently.
final Map<DateTime, List<Pair<Integer, ClusterByPartition>>> partitionsByBucket = new HashMap<>();
for (int i = 0; i < partitionBoundaries.ranges().size(); i++) {
ClusterByPartition partitionBoundary = partitionBoundaries.ranges().get(i);
final DateTime bucketDateTime = getBucketDateTime(partitionBoundary, segmentGranularity, keyReader);
partitionsByBucket.computeIfAbsent(bucketDateTime, ignored -> new ArrayList<>())
.add(Pair.of(i, partitionBoundary));
}
final Map<DateTime, List<Pair<Integer, ClusterByPartition>>> partitionsByBucket =
getPartitionsByBucket(partitionBoundaries, segmentGranularity, keyReader);

// Validate the buckets.
validateNumSegmentsPerBucketOrThrow(partitionsByBucket, segmentGranularity);

// Process buckets (time chunks) one at a time.
for (final Map.Entry<DateTime, List<Pair<Integer, ClusterByPartition>>> bucketEntry : partitionsByBucket.entrySet()) {
@@ -1090,6 +1134,7 @@ private List<SegmentIdWithShardSpec> generateSegmentIdsWithShardSpecsForReplace(
}

final List<Pair<Integer, ClusterByPartition>> ranges = bucketEntry.getValue();

String version = null;

final List<TaskLock> locks = context.taskActionClient().submit(new LockListAction());
@@ -59,6 +59,7 @@
import org.apache.druid.msq.indexing.error.TooManyInputFilesFault;
import org.apache.druid.msq.indexing.error.TooManyPartitionsFault;
import org.apache.druid.msq.indexing.error.TooManyRowsWithSameKeyFault;
import org.apache.druid.msq.indexing.error.TooManySegmentsInTimeChunkFault;
import org.apache.druid.msq.indexing.error.TooManyWarningsFault;
import org.apache.druid.msq.indexing.error.TooManyWorkersFault;
import org.apache.druid.msq.indexing.error.UnknownFault;
Expand Down Expand Up @@ -126,6 +127,7 @@ public class MSQIndexingModule implements DruidModule
TooManyInputFilesFault.class,
TooManyPartitionsFault.class,
TooManyRowsWithSameKeyFault.class,
TooManySegmentsInTimeChunkFault.class,
TooManyWarningsFault.class,
TooManyWorkersFault.class,
TooManyAttemptsForJob.class,
@@ -56,25 +56,30 @@ public class MSQTuningConfig
@Nullable
private final Integer rowsPerSegment;

@Nullable
private final Integer maxNumSegments;

@Nullable
private final IndexSpec indexSpec;

public MSQTuningConfig(
@JsonProperty("maxNumWorkers") @Nullable final Integer maxNumWorkers,
@JsonProperty("maxRowsInMemory") @Nullable final Integer maxRowsInMemory,
@JsonProperty("rowsPerSegment") @Nullable final Integer rowsPerSegment,
@JsonProperty("maxNumSegments") @Nullable final Integer maxNumSegments,
@JsonProperty("indexSpec") @Nullable final IndexSpec indexSpec
)
{
this.maxNumWorkers = maxNumWorkers;
this.maxRowsInMemory = maxRowsInMemory;
this.rowsPerSegment = rowsPerSegment;
this.maxNumSegments = maxNumSegments;
this.indexSpec = indexSpec;
}

public static MSQTuningConfig defaultConfig()
{
return new MSQTuningConfig(null, null, null, null);
return new MSQTuningConfig(null, null, null, null, null);
}

@JsonProperty("maxNumWorkers")
@@ -98,6 +103,13 @@ Integer getRowsPerSegmentForSerialization()
return rowsPerSegment;
}

@JsonProperty("maxNumSegments")
@JsonInclude(JsonInclude.Include.NON_NULL)
Integer getMaxNumSegmentsForSerialization()
{
return maxNumSegments;
}

@JsonProperty("indexSpec")
@JsonInclude(JsonInclude.Include.NON_NULL)
IndexSpec getIndexSpecForSerialization()
@@ -120,6 +132,12 @@ public int getRowsPerSegment()
return rowsPerSegment != null ? rowsPerSegment : PartitionsSpec.DEFAULT_MAX_ROWS_PER_SEGMENT;
}

@Nullable
public Integer getMaxNumSegments()
{
return maxNumSegments;
}

public IndexSpec getIndexSpec()
{
return indexSpec != null ? indexSpec : IndexSpec.DEFAULT;
@@ -138,13 +156,14 @@ public boolean equals(Object o)
return Objects.equals(maxNumWorkers, that.maxNumWorkers)
&& Objects.equals(maxRowsInMemory, that.maxRowsInMemory)
&& Objects.equals(rowsPerSegment, that.rowsPerSegment)
&& Objects.equals(maxNumSegments, that.maxNumSegments)
&& Objects.equals(indexSpec, that.indexSpec);
}

@Override
public int hashCode()
{
return Objects.hash(maxNumWorkers, maxRowsInMemory, rowsPerSegment, indexSpec);
return Objects.hash(maxNumWorkers, maxRowsInMemory, rowsPerSegment, maxNumSegments, indexSpec);
}

@Override
@@ -154,6 +173,7 @@ public String toString()
"maxNumWorkers=" + maxNumWorkers +
", maxRowsInMemory=" + maxRowsInMemory +
", rowsPerSegment=" + rowsPerSegment +
", maxNumSegments=" + maxNumSegments +
", indexSpec=" + indexSpec +
'}';
}
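
A small illustrative sketch of the new knob on MSQTuningConfig (the usage below is hypothetical and not part of the diff): maxNumSegments has no built-in numeric default, so the getter simply returns null when the limit is unset.

import org.apache.druid.msq.indexing.MSQTuningConfig;

public class TuningConfigMaxNumSegmentsSketch
{
  public static void main(String[] args)
  {
    // Hypothetical: cap each time chunk at 10 segments; leave every other knob at its default.
    final MSQTuningConfig tuningConfig = new MSQTuningConfig(null, null, null, 10, null);
    System.out.println(tuningConfig.getMaxNumSegments());                    // 10
    System.out.println(MSQTuningConfig.defaultConfig().getMaxNumSegments()); // null (no limit)
  }
}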
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.msq.indexing.error;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonTypeName;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.java.util.common.granularity.GranularityType;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.joda.time.DateTime;

import java.util.Objects;

@JsonTypeName(TooManySegmentsInTimeChunkFault.CODE)
public class TooManySegmentsInTimeChunkFault extends BaseMSQFault
{
public static final String CODE = "TooManySegmentsInTimeChunk";

private final DateTime timeChunk;
private final int numSegments;
private final int maxNumSegments;
private final Granularity segmentGranularity;

@JsonCreator
public TooManySegmentsInTimeChunkFault(
@JsonProperty("timeChunk") final DateTime timeChunk,
@JsonProperty("numSegments") final int numSegments,
@JsonProperty("maxNumSegments") final int maxNumSegments,
@JsonProperty("segmentGranularity") final Granularity segmentGranularity
)
{
super(
CODE,
"Too many segments requested to be generated in time chunk[%s] with granularity[%s]"
+ " (requested = [%,d], maximum = [%,d]). Please try breaking up your query or change the maximum using"
+ " the query context parameter[%s].",
timeChunk,
convertToGranularityString(segmentGranularity),
numSegments,
maxNumSegments,
MultiStageQueryContext.CTX_MAX_NUM_SEGMENTS
);
this.timeChunk = timeChunk;
this.numSegments = numSegments;
this.maxNumSegments = maxNumSegments;
this.segmentGranularity = segmentGranularity;
}

/**
* Convert the given granularity to a more user-friendly granularity string, when possible.
*/
private static String convertToGranularityString(final Granularity granularity)
{
// If it's a "standard" granularity, we get a nicer string from the GranularityType enum. For any other
// granularity, we just fall back to the toString(). See GranularityType#isStandard().
for (GranularityType value : GranularityType.values()) {
if (value.getDefaultGranularity().equals(granularity)) {
return value.name();
}
}
return granularity.toString();
}

@JsonProperty
public DateTime getTimeChunk()
{
return timeChunk;
}

@JsonProperty
public int getNumSegments()
{
return numSegments;
}

@JsonProperty
public int getMaxNumSegments()
{
return maxNumSegments;
}

@JsonProperty
public Granularity getSegmentGranularity()
{
return segmentGranularity;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
if (!super.equals(o)) {
return false;
}
TooManySegmentsInTimeChunkFault that = (TooManySegmentsInTimeChunkFault) o;
return numSegments == that.numSegments
&& maxNumSegments == that.maxNumSegments
&& Objects.equals(timeChunk, that.timeChunk)
&& Objects.equals(segmentGranularity, that.segmentGranularity);
}

@Override
public int hashCode()
{
return Objects.hash(super.hashCode(), timeChunk, numSegments, maxNumSegments, segmentGranularity);
}
}
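
For illustration, a small hypothetical snippet (not part of the commit) constructing the new fault; it assumes the usual MSQ fault accessors getErrorCode() and getErrorMessage() inherited via BaseMSQFault.

import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.msq.indexing.error.TooManySegmentsInTimeChunkFault;

public class TooManySegmentsFaultSketch
{
  public static void main(String[] args)
  {
    // A DAY time chunk that would generate 15 segments against a limit of 10.
    final TooManySegmentsInTimeChunkFault fault = new TooManySegmentsInTimeChunkFault(
        DateTimes.of("2024-06-26"),
        15,
        10,
        Granularities.DAY
    );
    System.out.println(fault.getErrorCode());    // TooManySegmentsInTimeChunk
    System.out.println(fault.getErrorMessage()); // names the time chunk, DAY granularity, 15 vs. 10, and maxNumSegments
  }
}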
@@ -162,6 +162,7 @@ public QueryResponse<Object[]> runQuery(final DruidQuery druidQuery)
final int maxNumWorkers = maxNumTasks - 1;
final int rowsPerSegment = MultiStageQueryContext.getRowsPerSegment(sqlQueryContext);
final int maxRowsInMemory = MultiStageQueryContext.getRowsInMemory(sqlQueryContext);
final Integer maxNumSegments = MultiStageQueryContext.getMaxNumSegments(sqlQueryContext);
final IndexSpec indexSpec = MultiStageQueryContext.getIndexSpec(sqlQueryContext, jsonMapper);
final boolean finalizeAggregations = MultiStageQueryContext.isFinalizeAggregations(sqlQueryContext);

@@ -279,7 +280,7 @@ public QueryResponse<Object[]> runQuery(final DruidQuery druidQuery)
.columnMappings(new ColumnMappings(columnMappings))
.destination(destination)
.assignmentStrategy(MultiStageQueryContext.getAssignmentStrategy(sqlQueryContext))
.tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment, indexSpec))
.tuningConfig(new MSQTuningConfig(maxNumWorkers, maxRowsInMemory, rowsPerSegment, maxNumSegments, indexSpec))
.build();

MSQTaskQueryMakerUtils.validateRealtimeReindex(querySpec);
@@ -149,6 +149,8 @@ public class MultiStageQueryContext

public static final String CTX_IS_REINDEX = "isReindex";

public static final String CTX_MAX_NUM_SEGMENTS = "maxNumSegments";

/**
* Controls sort order within segments. Normally, this is the same as the overall order of the query (from the
* CLUSTERED BY clause) but it can be overridden.
@@ -324,6 +326,12 @@ public static int getRowsInMemory(final QueryContext queryContext)
return queryContext.getInt(CTX_ROWS_IN_MEMORY, DEFAULT_ROWS_IN_MEMORY);
}

public static Integer getMaxNumSegments(final QueryContext queryContext)
{
// The default is null if the context parameter is not set.
return queryContext.getInt(CTX_MAX_NUM_SEGMENTS);
}
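
A brief usage sketch of the new getter (illustrative only; it assumes the QueryContext.of() and QueryContext.empty() factory methods on org.apache.druid.query.QueryContext):

import java.util.Map;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryContext;

public class GetMaxNumSegmentsSketch
{
  public static void main(String[] args)
  {
    // Context with the limit set: the getter surfaces it as an Integer.
    final QueryContext withLimit = QueryContext.of(Map.of("maxNumSegments", 10));
    System.out.println(MultiStageQueryContext.getMaxNumSegments(withLimit));    // 10

    // Context without the key: the getter returns null, which downstream code
    // (e.g. validateNumSegmentsPerBucketOrThrow) treats as "no per-time-chunk limit".
    final QueryContext withoutLimit = QueryContext.empty();
    System.out.println(MultiStageQueryContext.getMaxNumSegments(withoutLimit)); // null
  }
}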

public static List<String> getSortOrder(final QueryContext queryContext)
{
return decodeList(CTX_SORT_ORDER, queryContext.getString(CTX_SORT_ORDER));
