MSQ window functions: Fix query correctness issues when using multiple workers
Akshat-Jain committed Jul 25, 2024
1 parent 7e3fab5 commit 6b4b7eb
Showing 7 changed files with 196 additions and 30 deletions.
@@ -426,8 +426,8 @@ private static DataSourcePlan forQuery(
   {
     // check if parentContext has a window operator
     final Map<String, Object> windowShuffleMap = new HashMap<>();
-    if (parentContext != null && parentContext.containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) {
-      windowShuffleMap.put(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL, parentContext.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL));
+    if (parentContext != null && parentContext.containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC)) {
+      windowShuffleMap.put(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC, parentContext.get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC));
    }
    final QueryDefinition subQueryDef = queryKit.makeQueryDefinition(
        queryId,
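This hunk is the consumer side of a context hand-off: the parent query stashes the shuffle spec that the next window stage needs, and DataSourcePlan.forQuery forwards it into the subquery's context. Note that the whole ShuffleSpec object is now forwarded rather than just its ClusterBy columns, which is what lets downstream query kits distinguish a hash shuffle from a single-partition mix shuffle. A minimal sketch of the pattern, using the key name from the diff but plain JDK maps in place of the real MSQ classes:

import java.util.HashMap;
import java.util.Map;

class WindowShuffleContextSketch
{
  static final String NEXT_WINDOW_SHUFFLE_SPEC = "__windowShuffleSpec";

  // Mirrors the hunk: copy the spec from the parent context (if present)
  // into the map that will be merged into the subquery's context.
  static Map<String, Object> forwardShuffleSpec(Map<String, Object> parentContext)
  {
    final Map<String, Object> windowShuffleMap = new HashMap<>();
    if (parentContext != null && parentContext.containsKey(NEXT_WINDOW_SHUFFLE_SPEC)) {
      windowShuffleMap.put(NEXT_WINDOW_SHUFFLE_SPEC, parentContext.get(NEXT_WINDOW_SHUFFLE_SPEC));
    }
    return windowShuffleMap;
  }
}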
@@ -93,10 +93,9 @@ public QueryDefinition makeQueryDefinition(

     final QueryDefinitionBuilder queryDefBuilder = QueryDefinition.builder(queryId);
     if (nextShuffleSpec != null) {
-      final ClusterBy windowClusterBy = nextShuffleSpec.clusterBy();
       originalQuery = (WindowOperatorQuery) originalQuery.withOverriddenContext(ImmutableMap.of(
-          MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL,
-          windowClusterBy
+          MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC,
+          nextShuffleSpec
       ));
     }
     final DataSourcePlan dataSourcePlan = DataSourcePlan.forDataSource(
@@ -309,12 +308,16 @@ private ShuffleSpec findShuffleSpecForNextWindow(List<OperatorFactory> operatorF
       }
     }

-    if (partition == null || partition.getPartitionColumns().isEmpty()) {
+    if (partition == null) {
       // If operatorFactories doesn't have any partitioning factory, then we should keep the shuffle spec from previous stage.
       // This indicates that we already have the data partitioned correctly, and hence we don't need to do any shuffling.
       return null;
     }

+    if (partition.getPartitionColumns().isEmpty()) {
+      return MixShuffleSpec.instance();
+    }
+
     List<KeyColumn> keyColsOfWindow = new ArrayList<>();
     for (String partitionColumn : partition.getPartitionColumns()) {
       KeyColumn kc;
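The new branch gives findShuffleSpecForNextWindow three distinct outcomes instead of two, and the distinction is the heart of the multi-worker fix: an empty OVER() previously fell into the same path as "no partitioning operator at all". A compact sketch of the decision, with simplified stand-in types for the real ShuffleSpec, MixShuffleSpec, and HashShuffleSpec in org.apache.druid.msq.kernel:

import java.util.List;

interface ShuffleSpec {}

final class MixShuffleSpec implements ShuffleSpec
{
  private static final MixShuffleSpec INSTANCE = new MixShuffleSpec();

  static MixShuffleSpec instance()
  {
    return INSTANCE;
  }
}

final class HashShuffleSpec implements ShuffleSpec
{
  final List<String> partitionColumns;

  HashShuffleSpec(List<String> partitionColumns)
  {
    this.partitionColumns = partitionColumns;
  }
}

class WindowShuffleDecisionSketch
{
  // null columns  -> no partitioning operator: keep the previous stage's shuffle,
  //                  the data is already partitioned correctly (return null).
  // empty columns -> empty OVER(): one logical partition, so mix-shuffle every
  //                  row to a single worker instead of leaving rows where they are.
  // otherwise     -> hash-shuffle by the window's PARTITION BY columns.
  static ShuffleSpec forNextWindow(List<String> partitionColumns)
  {
    if (partitionColumns == null) {
      return null;
    }
    if (partitionColumns.isEmpty()) {
      return MixShuffleSpec.instance();
    }
    return new HashShuffleSpec(partitionColumns);
  }
}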
@@ -28,7 +28,6 @@
 import org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.java.util.common.granularity.Granularity;
 import org.apache.druid.msq.input.stage.StageInputSpec;
-import org.apache.druid.msq.kernel.HashShuffleSpec;
 import org.apache.druid.msq.kernel.QueryDefinition;
 import org.apache.druid.msq.kernel.QueryDefinitionBuilder;
 import org.apache.druid.msq.kernel.ShuffleSpec;
@@ -252,18 +251,10 @@ public QueryDefinition makeQueryDefinition(
    */
   private ShuffleSpec getShuffleSpecForNextWindow(GroupByQuery originalQuery, int maxWorkerCount)
   {
-    final ShuffleSpec nextShuffleWindowSpec;
-    if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) {
-      final ClusterBy windowClusterBy = (ClusterBy) originalQuery.getContext()
-                                                                 .get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL);
-      nextShuffleWindowSpec = new HashShuffleSpec(
-          windowClusterBy,
-          maxWorkerCount
-      );
-    } else {
-      nextShuffleWindowSpec = null;
+    if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC)) {
+      return (ShuffleSpec) originalQuery.getContext().get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC);
     }
-    return nextShuffleWindowSpec;
+    return null;
   }

   /**
@@ -132,17 +132,22 @@ public QueryDefinition makeQueryDefinition(
     // Update partition by of next window
     final RowSignature signatureSoFar = signatureBuilder.build();
     boolean addShuffle = true;
-    if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL)) {
-      final ClusterBy windowClusterBy = (ClusterBy) originalQuery.getContext()
-                                                                 .get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_COL);
-      for (KeyColumn c : windowClusterBy.getColumns()) {
-        if (!signatureSoFar.contains(c.columnName())) {
-          addShuffle = false;
-          break;
-        }
-      }
-      if (addShuffle) {
-        clusterByColumns.addAll(windowClusterBy.getColumns());
-      }
+    boolean windowHasEmptyOver = false;
+    if (originalQuery.getContext().containsKey(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC)) {
+      if (originalQuery.getContext().get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC) instanceof MixShuffleSpec) {
+        windowHasEmptyOver = true;
+      } else {
+        final ShuffleSpec shuffleSpec = (ShuffleSpec) originalQuery.getContext()
+                                                                   .get(MultiStageQueryContext.NEXT_WINDOW_SHUFFLE_SPEC);
+        for (KeyColumn c : shuffleSpec.clusterBy().getColumns()) {
+          if (!signatureSoFar.contains(c.columnName())) {
+            addShuffle = false;
+            break;
+          }
+        }
+        if (addShuffle) {
+          clusterByColumns.addAll(shuffleSpec.clusterBy().getColumns());
+        }
+      }
     } else {
       // Add partition boosting column.
@@ -178,6 +183,11 @@
       }
     }

+    // If window has an empty over, we want a single worker to process entire data for window function evaluation.
+    if (windowHasEmptyOver) {
+      scanShuffleSpec = MixShuffleSpec.instance();
+    }
+
     queryDefBuilder.add(
         StageDefinition.builder(Math.max(minStageNumber, queryDefBuilder.getNextStageNumber()))
                        .inputs(dataSourcePlan.getInputSpecs())
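The windowHasEmptyOver flag closes the loop on the scan side: when the stored spec is a MixShuffleSpec, the scan stage's own shuffle is replaced with a mix shuffle too, so every row lands in one partition and a single worker evaluates the empty-OVER window over the complete dataset. A toy model of the kind of correctness bug this prevents, using plain Java rather than the Druid execution path:

import java.util.Arrays;
import java.util.List;

class EmptyOverSketch
{
  // With an empty OVER (), ROW_NUMBER() must be assigned over the full dataset.
  // If each worker numbered only its own slice, every worker would emit 1, 2, ...
  // and the union of slices would contain duplicate row numbers.
  public static void main(String[] args)
  {
    List<List<String>> workerSlices = Arrays.asList(
        Arrays.asList("a", "b"),   // rows held by worker 0
        Arrays.asList("c", "d")    // rows held by worker 1
    );

    // Buggy shape: per-worker numbering restarts at 1 on every worker.
    for (List<String> slice : workerSlices) {
      int sliceRowNumber = 0;
      for (String row : slice) {
        System.out.println(row + " -> " + (++sliceRowNumber)); // a->1, b->2, c->1, d->2 (duplicates)
      }
    }

    // Fixed shape: one worker sees the concatenation of all slices.
    int rowNumber = 0;
    for (List<String> slice : workerSlices) {
      for (String row : slice) {
        System.out.println(row + " -> " + (++rowNumber)); // a->1, b->2, c->3, d->4
      }
    }
  }
}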
@@ -167,7 +167,7 @@ public class MultiStageQueryContext
   public static final String CTX_ARRAY_INGEST_MODE = "arrayIngestMode";
   public static final ArrayIngestMode DEFAULT_ARRAY_INGEST_MODE = ArrayIngestMode.MVD;

-  public static final String NEXT_WINDOW_SHUFFLE_COL = "__windowShuffleCol";
+  public static final String NEXT_WINDOW_SHUFFLE_SPEC = "__windowShuffleSpec";

   public static final String MAX_ROWS_MATERIALIZED_IN_WINDOW = "maxRowsMaterializedInWindow";

@@ -37,7 +37,9 @@
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.UnnestDataSource;
 import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.aggregation.FilteredAggregatorFactory;
 import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
 import org.apache.druid.query.dimension.DefaultDimensionSpec;
 import org.apache.druid.query.groupby.GroupByQuery;
@@ -48,6 +50,7 @@
 import org.apache.druid.query.operator.window.WindowFrame;
 import org.apache.druid.query.operator.window.WindowFramedAggregateProcessor;
 import org.apache.druid.query.operator.window.WindowOperatorFactory;
+import org.apache.druid.query.operator.window.ranking.WindowRowNumberProcessor;
 import org.apache.druid.query.scan.ScanQuery;
 import org.apache.druid.query.spec.LegacySegmentSpec;
 import org.apache.druid.segment.column.ColumnType;
@@ -65,6 +68,8 @@

 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;

@@ -1842,7 +1847,7 @@ public void testSelectWithWikipediaEmptyOverWithCustomContext(String contextName
         .setSql(
             "select cityName, added, SUM(added) OVER () cc from wikipedia")
         .setQueryContext(customContext)
-        .setExpectedMSQFault(new TooManyRowsInAWindowFault(15676, 200))
+        .setExpectedMSQFault(new TooManyRowsInAWindowFault(26022, 200))
         .verifyResults();
   }

@@ -2048,4 +2053,141 @@ public void testReplaceGroupByOnWikipedia(String contextName, Map<String, Object
.setExpectedSegment(ImmutableSet.of(SegmentId.of("foo1", Intervals.ETERNITY, "test", 0)))
.verifyResults();
}

@MethodSource("data")
@ParameterizedTest(name = "{index}:with context {0}")
public void testWindowOnMixOfEmptyAndNonEmptyOverWithMultipleWorkers(String contextName, Map<String, Object> context)
{
final Map<String, Object> multipleWorkerContext = new HashMap<>(context);
multipleWorkerContext.put(MultiStageQueryContext.CTX_MAX_NUM_TASKS, 5);

final RowSignature rowSignature = RowSignature.builder()
.add("countryName", ColumnType.STRING)
.add("cityName", ColumnType.STRING)
.add("channel", ColumnType.STRING)
.add("c1", ColumnType.LONG)
.add("c2", ColumnType.LONG)
.build();

final Map<String, Object> contextWithRowSignature =
ImmutableMap.<String, Object>builder()
.putAll(multipleWorkerContext)
.put(
DruidQuery.CTX_SCAN_SIGNATURE,
"[{\"name\":\"d0\",\"type\":\"STRING\"},{\"name\":\"d1\",\"type\":\"STRING\"},{\"name\":\"d2\",\"type\":\"STRING\"},{\"name\":\"w0\",\"type\":\"LONG\"},{\"name\":\"w1\",\"type\":\"LONG\"}]"
)
.build();

final GroupByQuery groupByQuery = GroupByQuery.builder()
.setDataSource(CalciteTests.WIKIPEDIA)
.setInterval(querySegmentSpec(Filtration
.eternity()))
.setGranularity(Granularities.ALL)
.setDimensions(dimensions(
new DefaultDimensionSpec(
"countryName",
"d0",
ColumnType.STRING
),
new DefaultDimensionSpec(
"cityName",
"d1",
ColumnType.STRING
),
new DefaultDimensionSpec(
"channel",
"d2",
ColumnType.STRING
)
))
.setDimFilter(in("countryName", ImmutableList.of("Austria", "Republic of Korea")))
.setContext(multipleWorkerContext)
.build();

final AggregatorFactory[] aggs = {
new FilteredAggregatorFactory(new CountAggregatorFactory("w1"), notNull("d2"), "w1")
};

final WindowOperatorQuery windowQuery = new WindowOperatorQuery(
new QueryDataSource(groupByQuery),
new LegacySegmentSpec(Intervals.ETERNITY),
multipleWorkerContext,
RowSignature.builder()
.add("d0", ColumnType.STRING)
.add("d1", ColumnType.STRING)
.add("d2", ColumnType.STRING)
.add("w0", ColumnType.LONG)
.add("w1", ColumnType.LONG).build(),
ImmutableList.of(
new NaiveSortOperatorFactory(ImmutableList.of(ColumnWithDirection.ascending("d0"), ColumnWithDirection.ascending("d1"), ColumnWithDirection.ascending("d2"))),
new NaivePartitioningOperatorFactory(Collections.emptyList()),
new WindowOperatorFactory(new WindowRowNumberProcessor("w0")),
new NaiveSortOperatorFactory(ImmutableList.of(ColumnWithDirection.ascending("d1"), ColumnWithDirection.ascending("d0"), ColumnWithDirection.ascending("d2"))),
new NaivePartitioningOperatorFactory(Collections.singletonList("d1")),
new WindowOperatorFactory(new WindowFramedAggregateProcessor(WindowFrame.forOrderBy("d0", "d1", "d2"), aggs))
),
ImmutableList.of()
);

final ScanQuery scanQuery = Druids.newScanQueryBuilder()
.dataSource(new QueryDataSource(windowQuery))
.intervals(querySegmentSpec(Filtration.eternity()))
.columns("d0", "d1", "d2", "w0", "w1")
.orderBy(
ImmutableList.of(
new ScanQuery.OrderBy("d0", ScanQuery.Order.ASCENDING),
new ScanQuery.OrderBy("d1", ScanQuery.Order.ASCENDING),
new ScanQuery.OrderBy("d2", ScanQuery.Order.ASCENDING)
)
)
.columnTypes(ColumnType.STRING, ColumnType.STRING, ColumnType.STRING, ColumnType.LONG, ColumnType.LONG)
.limit(Long.MAX_VALUE)
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.context(contextWithRowSignature)
.build();

final String sql = "select countryName, cityName, channel, \n"
+ "row_number() over (order by countryName, cityName, channel) as c1, \n"
+ "count(channel) over (partition by cityName order by countryName, cityName, channel) as c2\n"
+ "from wikipedia\n"
+ "where countryName in ('Austria', 'Republic of Korea')\n"
+ "group by countryName, cityName, channel "
+ "order by countryName, cityName, channel";

testSelectQuery()
.setSql(sql)
.setExpectedMSQSpec(MSQSpec.builder()
.query(scanQuery)
.columnMappings(
new ColumnMappings(ImmutableList.of(
new ColumnMapping("d0", "countryName"),
new ColumnMapping("d1", "cityName"),
new ColumnMapping("d2", "channel"),
new ColumnMapping("w0", "c1"),
new ColumnMapping("w1", "c2")
)
))
.tuningConfig(MSQTuningConfig.defaultConfig())
.build())
.setExpectedRowSignature(rowSignature)
.setExpectedResultRows(
ImmutableList.<Object[]>of(
new Object[]{"Austria", null, "#de.wikipedia", 1L, 1L},
new Object[]{"Austria", "Horsching", "#de.wikipedia", 2L, 1L},
new Object[]{"Austria", "Vienna", "#de.wikipedia", 3L, 1L},
new Object[]{"Austria", "Vienna", "#es.wikipedia", 4L, 2L},
new Object[]{"Austria", "Vienna", "#tr.wikipedia", 5L, 3L},
new Object[]{"Republic of Korea", null, "#en.wikipedia", 6L, 2L},
new Object[]{"Republic of Korea", null, "#ja.wikipedia", 7L, 3L},
new Object[]{"Republic of Korea", null, "#ko.wikipedia", 8L, 4L},
new Object[]{"Republic of Korea", "Jeonju", "#ko.wikipedia", 9L, 1L},
new Object[]{"Republic of Korea", "Seongnam-si", "#ko.wikipedia", 10L, 1L},
new Object[]{"Republic of Korea", "Seoul", "#ko.wikipedia", 11L, 1L},
new Object[]{"Republic of Korea", "Suwon-si", "#ko.wikipedia", 12L, 1L},
new Object[]{"Republic of Korea", "Yongsan-dong", "#ko.wikipedia", 13L, 1L}
)
)
.setQueryContext(multipleWorkerContext)
.verifyResults();
}
}
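The new test pins down the mixed case: c1 is a row_number with ORDER BY but no PARTITION BY (a single global partition, hence the mix shuffle and the NaivePartitioningOperatorFactory over an empty column list), while c2 is partitioned by cityName, so the plan needs the single-worker mix shuffle and a hash shuffle in sequence. A plain-Java illustration of the window semantics the expected rows encode, on toy data rather than the Druid execution path:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class MixedOverSemanticsSketch
{
  public static void main(String[] args)
  {
    // (countryName, cityName, channel) rows, a subset of the shape the test uses.
    List<String[]> rows = new ArrayList<>(Arrays.asList(
        new String[]{"Austria", "Horsching", "#de.wikipedia"},
        new String[]{"Austria", "Vienna", "#de.wikipedia"},
        new String[]{"Austria", "Vienna", "#es.wikipedia"},
        new String[]{"Republic of Korea", "Seoul", "#ko.wikipedia"}
    ));
    rows.sort(
        Comparator.comparing((String[] r) -> r[0])
                  .thenComparing(r -> r[1])
                  .thenComparing(r -> r[2])
    );

    final Map<String, Integer> runningCountPerCity = new HashMap<>();
    int rowNumber = 0;
    for (String[] r : rows) {
      rowNumber++;                                               // c1: ROW_NUMBER() over the global order
      int c2 = runningCountPerCity.merge(r[1], 1, Integer::sum); // c2: running COUNT(channel) within cityName
      System.out.printf("%s, %s, %s -> c1=%d, c2=%d%n", r[0], r[1], r[2], rowNumber, c2);
    }
    // On this toy subset: Horsching -> c1=1, c2=1; Vienna #de -> c1=2, c2=1;
    // Vienna #es -> c1=3, c2=2; Seoul -> c1=4, c2=1, matching the pattern of
    // the test's expected rows.
  }
}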
@@ -30,6 +30,7 @@

 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;

 public class WindowRowNumberProcessor implements Processor
 {
@@ -137,4 +138,23 @@ public List<String> getOutputColumnNames()
   {
     return Collections.singletonList(outputColumn);
   }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    WindowRowNumberProcessor that = (WindowRowNumberProcessor) o;
+    return Objects.equals(outputColumn, that.outputColumn);
+  }
+
+  @Override
+  public int hashCode()
+  {
+    return Objects.hashCode(outputColumn);
+  }
 }
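A likely motivation for the new equals()/hashCode(): the test above builds an expected WindowOperatorFactory containing a WindowRowNumberProcessor, and matching the expected plan against the actual one needs value equality on the processor. A small standalone check of the contract, assuming only that the class is on the classpath:

import org.apache.druid.query.operator.window.ranking.WindowRowNumberProcessor;

public class WindowProcessorEqualityCheck
{
  public static void main(String[] args)
  {
    // Two processors that write row numbers to the same output column
    // should now compare equal, with matching hash codes.
    final WindowRowNumberProcessor a = new WindowRowNumberProcessor("w0");
    final WindowRowNumberProcessor b = new WindowRowNumberProcessor("w0");
    System.out.println(a.equals(b));                                  // true
    System.out.println(a.hashCode() == b.hashCode());                 // true
    System.out.println(a.equals(new WindowRowNumberProcessor("w1"))); // false
  }
}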
