Skip to content

Commit

Permalink
MSQ window functions: Fix query correctness issues when using multipl…
Browse files Browse the repository at this point in the history
…e workers (#16804)

This PR fixes query correctness issues for MSQ window functions when using more than 1 worker (that is, maxNumTasks > 2).

Currently, we were keeping the shuffle spec of the previous stage when we didn't have any partition columns for window stage. This PR changes it to override the shuffle spec of the previous stage to MixShuffleSpec (if we have a window function with empty over clause) so that the window stage gets a single partition to work on.

A test has been added for a query which returned incorrect results prior to this change when using more than 1 workers.
  • Loading branch information
Akshat-Jain authored Aug 6, 2024
1 parent ed6b547 commit c3aa033
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageDefinitionBuilder;
import org.apache.druid.msq.querykit.common.SortMergeJoinFrameProcessorFactory;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.DataSource;
import org.apache.druid.query.FilteredDataSource;
import org.apache.druid.query.InlineDataSource;
Expand Down Expand Up @@ -424,21 +423,11 @@ private static DataSourcePlan forQuery(
@Nullable final QueryContext parentContext
)
{
// check if parentContext has a window operator
final Map<String, Object> windowShuffleMap = new HashMap<>();
if (parentContext != null && parentContext.containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) {
windowShuffleMap.put(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL, parentContext.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL));
}
final QueryDefinition subQueryDef = queryKit.makeQueryDefinition(
queryId,
// Subqueries ignore SQL_INSERT_SEGMENT_GRANULARITY, even if set in the context. It's only used for the
// outermost query, and setting it for the subquery makes us erroneously add bucketing where it doesn't belong.
windowShuffleMap.isEmpty()
? dataSource.getQuery()
.withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY)
: dataSource.getQuery()
.withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY)
.withOverriddenContext(windowShuffleMap),
dataSource.getQuery().withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY),
queryKit,
ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount),
maxWorkerCount,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.apache.druid.msq.querykit;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.frame.key.ClusterBy;
import org.apache.druid.frame.key.KeyColumn;
import org.apache.druid.frame.key.KeyOrder;
Expand Down Expand Up @@ -88,17 +87,6 @@ public QueryDefinition makeQueryDefinition(
List<List<OperatorFactory>> operatorList = getOperatorListFromQuery(originalQuery);
log.info("Created operatorList with operator factories: [%s]", operatorList);

ShuffleSpec nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(0), maxWorkerCount);
// add this shuffle spec to the last stage of the inner query

final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId);
if (nextShuffleSpec != null) {
final ClusterBy windowClusterBy = nextShuffleSpec.clusterBy();
originalQuery = (WindowOperatorQuery) originalQuery.withOverriddenContext(ImmutableMap.of(
MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL,
windowClusterBy
));
}
final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource(
queryKit,
queryId,
Expand All @@ -112,7 +100,8 @@ public QueryDefinition makeQueryDefinition(
false
);

dataSourcePlan.getSubQueryDefBuilder().ifPresent(queryDefBuilder::addAll);
ShuffleSpec nextShuffleSpec = findShuffleSpecForNextWindow(operatorList.get(0), maxWorkerCount);
final QueryDefinitionBuilder queryDefBuilder = makeQueryDefinitionBuilder(queryId, dataSourcePlan, nextShuffleSpec);

final int firstStageNumber = Math.max(minStageNumber, queryDefBuilder.getNextStageNumber());
final WindowOperatorQuery queryToRun = (WindowOperatorQuery) originalQuery.withDataSource(dataSourcePlan.getNewDataSource());
Expand Down Expand Up @@ -309,12 +298,16 @@ private ShuffleSpec findShuffleSpecForNextWindow(List<OperatorFactory> operatorF
}
}

if (partition == null || partition.getPartitionColumns().isEmpty()) {
if (partition == null) {
// If operatorFactories doesn't have any partitioning factory, then we should keep the shuffle spec from previous stage.
// This indicates that we already have the data partitioned correctly, and hence we don't need to do any shuffling.
return null;
}

if (partition.getPartitionColumns().isEmpty()) {
return MixShuffleSpec.instance();
}

List<KeyColumn> keyColsOfWindow = new ArrayList<>();
for (String partitionColumn : partition.getPartitionColumns()) {
KeyColumn kc;
Expand All @@ -328,4 +321,29 @@ private ShuffleSpec findShuffleSpecForNextWindow(List<OperatorFactory> operatorF

return new HashShuffleSpec(new ClusterBy(keyColsOfWindow, 0), maxWorkerCount);
}

/**
* Override the shuffle spec of the last stage based on the shuffling required by the first window stage.
* @param queryId
* @param dataSourcePlan
* @param shuffleSpec
* @return
*/
private QueryDefinitionBuilder makeQueryDefinitionBuilder(String queryId, DataSourcePlan dataSourcePlan, ShuffleSpec shuffleSpec)
{
final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId);
int previousStageNumber = dataSourcePlan.getSubQueryDefBuilder().get().build().getFinalStageDefinition().getStageNumber();
for (final StageDefinition stageDef : dataSourcePlan.getSubQueryDefBuilder().get().build().getStageDefinitions()) {
if (stageDef.getStageNumber() == previousStageNumber) {
RowSignature rowSignature = QueryKitUtils.sortableSignature(
stageDef.getSignature(),
shuffleSpec.clusterBy().getColumns()
);
queryDefBuilder.add(StageDefinition.builder(stageDef).shuffleSpec(shuffleSpec).signature(rowSignature));
} else {
queryDefBuilder.add(StageDefinition.builder(stageDef));
}
}
return queryDefBuilder;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.granularity.Granularity;
import org.apache.druid.msq.input.stage.StageInputSpec;
import org.apache.druid.msq.kernel.HashShuffleSpec;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
import org.apache.druid.msq.kernel.ShuffleSpec;
Expand All @@ -39,7 +38,6 @@
import org.apache.druid.msq.querykit.ShuffleSpecFactories;
import org.apache.druid.msq.querykit.ShuffleSpecFactory;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.DimensionComparisonUtils;
import org.apache.druid.query.Query;
import org.apache.druid.query.dimension.DimensionSpec;
Expand Down Expand Up @@ -168,104 +166,40 @@ public QueryDefinition makeQueryDefinition(
partitionBoost
);

final ShuffleSpec nextShuffleWindowSpec = getShuffleSpecForNextWindow(originalQuery, maxWorkerCount);
queryDefBuilder.add(
StageDefinition.builder(firstStageNumber + 1)
.inputs(new StageInputSpec(firstStageNumber))
.signature(resultSignature)
.maxWorkerCount(maxWorkerCount)
.shuffleSpec(
shuffleSpecFactoryPostAggregation != null
? shuffleSpecFactoryPostAggregation.build(resultClusterBy, false)
: null
)
.processorFactory(new GroupByPostShuffleFrameProcessorFactory(queryToRun))
);

if (nextShuffleWindowSpec == null) {
if (doLimitOrOffset) {
final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false);
final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec();
queryDefBuilder.add(
StageDefinition.builder(firstStageNumber + 1)
.inputs(new StageInputSpec(firstStageNumber))
StageDefinition.builder(firstStageNumber + 2)
.inputs(new StageInputSpec(firstStageNumber + 1))
.signature(resultSignature)
.maxWorkerCount(maxWorkerCount)
.shuffleSpec(
shuffleSpecFactoryPostAggregation != null
? shuffleSpecFactoryPostAggregation.build(resultClusterBy, false)
: null
)
.processorFactory(new GroupByPostShuffleFrameProcessorFactory(queryToRun))
);

if (doLimitOrOffset) {
final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false);
final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec();
queryDefBuilder.add(
StageDefinition.builder(firstStageNumber + 2)
.inputs(new StageInputSpec(firstStageNumber + 1))
.signature(resultSignature)
.maxWorkerCount(1)
.shuffleSpec(finalShuffleSpec)
.processorFactory(
new OffsetLimitFrameProcessorFactory(
limitSpec.getOffset(),
limitSpec.isLimited() ? (long) limitSpec.getLimit() : null
)
)
);
}
} else {
final RowSignature stageSignature;
// sort the signature to make sure the prefix is aligned
stageSignature = QueryKitUtils.sortableSignature(
resultSignature,
nextShuffleWindowSpec.clusterBy().getColumns()
);


queryDefBuilder.add(
StageDefinition.builder(firstStageNumber + 1)
.inputs(new StageInputSpec(firstStageNumber))
.signature(stageSignature)
.maxWorkerCount(maxWorkerCount)
.shuffleSpec(doLimitOrOffset ? (shuffleSpecFactoryPostAggregation != null
? shuffleSpecFactoryPostAggregation.build(
resultClusterBy,
false
.maxWorkerCount(1)
.shuffleSpec(finalShuffleSpec)
.processorFactory(
new OffsetLimitFrameProcessorFactory(
limitSpec.getOffset(),
limitSpec.isLimited() ? (long) limitSpec.getLimit() : null
)
)
: null) : nextShuffleWindowSpec)
.processorFactory(new GroupByPostShuffleFrameProcessorFactory(queryToRun))
);
if (doLimitOrOffset) {
final DefaultLimitSpec limitSpec = (DefaultLimitSpec) queryToRun.getLimitSpec();
final ShuffleSpec finalShuffleSpec = resultShuffleSpecFactory.build(resultClusterBy, false);
queryDefBuilder.add(
StageDefinition.builder(firstStageNumber + 2)
.inputs(new StageInputSpec(firstStageNumber + 1))
.signature(resultSignature)
.maxWorkerCount(1)
.shuffleSpec(finalShuffleSpec)
.processorFactory(
new OffsetLimitFrameProcessorFactory(
limitSpec.getOffset(),
limitSpec.isLimited() ? (long) limitSpec.getLimit() : null
)
)
);
}
}

return queryDefBuilder.build();
}

/**
* @param originalQuery which has the context for the next shuffle if that's present in the next window
* @param maxWorkerCount max worker count
* @return shuffle spec without partition boosting for next stage, null if there is no partition by for next window
*/
private ShuffleSpec getShuffleSpecForNextWindow(GroupByQuery originalQuery, int maxWorkerCount)
{
final ShuffleSpec nextShuffleWindowSpec;
if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) {
final ClusterBy windowClusterBy = (ClusterBy) originalQuery.getContext()
.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL);
nextShuffleWindowSpec = new HashShuffleSpec(
windowClusterBy,
maxWorkerCount
);
} else {
nextShuffleWindowSpec = null;
}
return nextShuffleWindowSpec;
}

/**
* Intermediate signature of a particular {@link GroupByQuery}. Does not include post-aggregators, and all
* aggregations are nonfinalized.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.apache.druid.msq.querykit.ShuffleSpecFactories;
import org.apache.druid.msq.querykit.ShuffleSpecFactory;
import org.apache.druid.msq.querykit.common.OffsetLimitFrameProcessorFactory;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.Query;
import org.apache.druid.query.scan.ScanQuery;
import org.apache.druid.segment.column.ColumnType;
Expand Down Expand Up @@ -129,26 +128,8 @@ public QueryDefinition makeQueryDefinition(
);
}

// Update partition by of next window
final RowSignature signatureSoFar = signatureBuilder.build();
boolean addShuffle = true;
if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) {
final ClusterBy windowClusterBy = (ClusterBy) originalQuery.getContext()
.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL);
for (KeyColumn c : windowClusterBy.getColumns()) {
if (!signatureSoFar.contains(c.columnName())) {
addShuffle = false;
break;
}
}
if (addShuffle) {
clusterByColumns.addAll(windowClusterBy.getColumns());
}
} else {
// Add partition boosting column.
clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING));
signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG);
}
clusterByColumns.add(new KeyColumn(QueryKitUtils.PARTITION_BOOST_COLUMN, KeyOrder.ASCENDING));
signatureBuilder.add(QueryKitUtils.PARTITION_BOOST_COLUMN, ColumnType.LONG);

final ClusterBy clusterBy =
QueryKitUtils.clusterByWithSegmentGranularity(new ClusterBy(clusterByColumns, 0), segmentGranularity);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,6 @@ public class MultiStageQueryContext
public static final String CTX_ARRAY_INGEST_MODE = "arrayIngestMode";
public static final ArrayIngestMode DEFAULT_ARRAY_INGEST_MODE = ArrayIngestMode.ARRAY;

public static final String NEXT_WINDOW_SHUFFLE_COL = "__windowShuffleCol";

public static final String MAX_ROWS_MATERIALIZED_IN_WINDOW = "maxRowsMaterializedInWindow";

public static final String CTX_SKIP_TYPE_VERIFICATION = "skipTypeVerification";
Expand Down
Loading

0 comments on commit c3aa033

Please sign in to comment.