Skip to content

Commit

Permalink
Add "targetPartitionsPerWorker" setting for MSQ. (#17048)
Browse files Browse the repository at this point in the history
As we move towards multi-threaded MSQ workers, it helps for parallelism
to generate more than one partition per worker. That way, we can fully
utilize all worker threads throughout all stages.

The default value is the number of processing threads. Currently, this
is hard-coded to 1 for peons, but that is expected to change in the future.
  • Loading branch information
gianm authored Sep 13, 2024
1 parent 654e0b4 commit d3f86ba
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import org.apache.druid.msq.input.table.TableInputSpec;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernelConfig;
import org.apache.druid.msq.querykit.QueryKit;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.server.DruidNode;

/**
Expand Down Expand Up @@ -100,4 +102,10 @@ WorkerManager newWorkerManager(
* Client for communicating with workers.
*/
WorkerClient newWorkerClient();

/**
 * Default target number of output partitions per worker, used by {@link QueryKit#makeQueryDefinition} when sizing
 * shuffles. Generating more than one partition per worker allows multi-threaded workers to keep all processing
 * threads busy across stages. Can be overridden per-query via
 * {@link MultiStageQueryContext#CTX_TARGET_PARTITIONS_PER_WORKER}.
 */
int defaultTargetPartitionsPerWorker();
}
Original file line number Diff line number Diff line change
Expand Up @@ -563,11 +563,16 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer)
this.netClient = new ExceptionWrappingWorkerClient(context.newWorkerClient());
closer.register(netClient);

final QueryContext queryContext = querySpec.getQuery().context();
final QueryDefinition queryDef = makeQueryDefinition(
queryId(),
makeQueryControllerToolKit(),
querySpec,
context.jsonMapper(),
MultiStageQueryContext.getTargetPartitionsPerWorkerWithDefault(
queryContext,
context.defaultTargetPartitionsPerWorker()
),
resultsContext
);

Expand Down Expand Up @@ -612,7 +617,7 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer)
);
}

final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(querySpec.getQuery().context());
final long maxParseExceptions = MultiStageQueryContext.getMaxParseExceptions(queryContext);
this.faultsExceededChecker = new FaultsExceededChecker(
ImmutableMap.of(CannotParseExternalDataFault.CODE, maxParseExceptions)
);
Expand All @@ -624,7 +629,7 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer)
stageDefinition.getId().getStageNumber(),
finalizeClusterStatisticsMergeMode(
stageDefinition,
MultiStageQueryContext.getClusterStatisticsMergeMode(querySpec.getQuery().context())
MultiStageQueryContext.getClusterStatisticsMergeMode(queryContext)
)
)
);
Expand Down Expand Up @@ -1718,17 +1723,18 @@ private static QueryDefinition makeQueryDefinition(
@SuppressWarnings("rawtypes") final QueryKit toolKit,
final MSQSpec querySpec,
final ObjectMapper jsonMapper,
final int targetPartitionsPerWorker,
final ResultsContext resultsContext
)
{
final MSQTuningConfig tuningConfig = querySpec.getTuningConfig();
final ColumnMappings columnMappings = querySpec.getColumnMappings();
final Query<?> queryToPlan;
final ShuffleSpecFactory shuffleSpecFactory;
final ShuffleSpecFactory resultShuffleSpecFactory;

if (MSQControllerTask.isIngestion(querySpec)) {
shuffleSpecFactory = querySpec.getDestination()
.getShuffleSpecFactory(tuningConfig.getRowsPerSegment());
resultShuffleSpecFactory = querySpec.getDestination()
.getShuffleSpecFactory(tuningConfig.getRowsPerSegment());

if (!columnMappings.hasUniqueOutputColumnNames()) {
// We do not expect to hit this case in production, because the SQL validator checks that column names
Expand All @@ -1752,7 +1758,7 @@ private static QueryDefinition makeQueryDefinition(
queryToPlan = querySpec.getQuery();
}
} else {
shuffleSpecFactory =
resultShuffleSpecFactory =
querySpec.getDestination()
.getShuffleSpecFactory(MultiStageQueryContext.getRowsPerPage(querySpec.getQuery().context()));
queryToPlan = querySpec.getQuery();
Expand All @@ -1765,8 +1771,9 @@ private static QueryDefinition makeQueryDefinition(
queryId,
queryToPlan,
toolKit,
shuffleSpecFactory,
resultShuffleSpecFactory,
tuningConfig.getMaxNumWorkers(),
targetPartitionsPerWorker,
0
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public class IndexerControllerContext implements ControllerContext
private final ServiceClientFactory clientFactory;
private final OverlordClient overlordClient;
private final ServiceMetricEvent.Builder metricBuilder;
private final MemoryIntrospector memoryIntrospector;

public IndexerControllerContext(
final MSQControllerTask task,
Expand All @@ -89,6 +90,7 @@ public IndexerControllerContext(
this.clientFactory = clientFactory;
this.overlordClient = overlordClient;
this.metricBuilder = new ServiceMetricEvent.Builder();
this.memoryIntrospector = injector.getInstance(MemoryIntrospector.class);
IndexTaskUtils.setTaskDimensions(metricBuilder, task);
}

Expand All @@ -98,7 +100,6 @@ public ControllerQueryKernelConfig queryKernelConfig(
final QueryDefinition queryDef
)
{
final MemoryIntrospector memoryIntrospector = injector.getInstance(MemoryIntrospector.class);
final ControllerMemoryParameters memoryParameters =
ControllerMemoryParameters.createProductionInstance(
memoryIntrospector,
Expand Down Expand Up @@ -200,6 +201,14 @@ public WorkerManager newWorkerManager(
);
}

@Override
public int defaultTargetPartitionsPerWorker()
{
  // Tasks are assumed symmetric: each worker task sees the same number of processors as this controller's JVM.
  // Targeting one partition per processor per task maximizes parallelism once workers are multi-threaded.
  final int processorsPerTask = memoryIntrospector.numProcessorsInJvm();
  return processorsPerTask;
}

/**
* Helper method for {@link #queryKernelConfig(MSQSpec, QueryDefinition)}. Also used in tests.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ public class DataSourcePlan
* @param maxWorkerCount maximum number of workers for subqueries
* @param minStageNumber starting stage number for subqueries
* @param broadcast whether the plan should broadcast data for this datasource
* @param targetPartitionsPerWorker preferred number of partitions per worker for subqueries
*/
@SuppressWarnings("rawtypes")
public static DataSourcePlan forDataSource(
Expand All @@ -146,6 +147,7 @@ public static DataSourcePlan forDataSource(
@Nullable DimFilter filter,
@Nullable Set<String> filterFields,
final int maxWorkerCount,
final int targetPartitionsPerWorker,
final int minStageNumber,
final boolean broadcast
)
Expand Down Expand Up @@ -186,6 +188,7 @@ public static DataSourcePlan forDataSource(
(FilteredDataSource) dataSource,
querySegmentSpec,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast
);
Expand All @@ -197,6 +200,7 @@ public static DataSourcePlan forDataSource(
(UnnestDataSource) dataSource,
querySegmentSpec,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast
);
Expand All @@ -207,6 +211,7 @@ public static DataSourcePlan forDataSource(
queryId,
(QueryDataSource) dataSource,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast,
queryContext
Expand All @@ -221,6 +226,7 @@ public static DataSourcePlan forDataSource(
filter,
filterFields,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast
);
Expand All @@ -242,6 +248,7 @@ public static DataSourcePlan forDataSource(
filter,
filterFields,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast
);
Expand All @@ -253,6 +260,7 @@ public static DataSourcePlan forDataSource(
(JoinDataSource) dataSource,
querySegmentSpec,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast
);
Expand Down Expand Up @@ -418,6 +426,7 @@ private static DataSourcePlan forQuery(
final String queryId,
final QueryDataSource dataSource,
final int maxWorkerCount,
final int targetPartitionsPerWorker,
final int minStageNumber,
final boolean broadcast,
@Nullable final QueryContext parentContext
Expand All @@ -429,8 +438,9 @@ private static DataSourcePlan forQuery(
// outermost query, and setting it for the subquery makes us erroneously add bucketing where it doesn't belong.
dataSource.getQuery().withOverriddenContext(CONTEXT_MAP_NO_SEGMENT_GRANULARITY),
queryKit,
ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount),
ShuffleSpecFactories.globalSortWithMaxPartitionCount(maxWorkerCount * targetPartitionsPerWorker),
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber
);

Expand All @@ -451,6 +461,7 @@ private static DataSourcePlan forFilteredDataSource(
final FilteredDataSource dataSource,
final QuerySegmentSpec querySegmentSpec,
final int maxWorkerCount,
final int targetPartitionsPerWorker,
final int minStageNumber,
final boolean broadcast
)
Expand All @@ -464,6 +475,7 @@ private static DataSourcePlan forFilteredDataSource(
null,
null,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast
);
Expand Down Expand Up @@ -491,6 +503,7 @@ private static DataSourcePlan forUnnest(
final UnnestDataSource dataSource,
final QuerySegmentSpec querySegmentSpec,
final int maxWorkerCount,
final int targetPartitionsPerWorker,
final int minStageNumber,
final boolean broadcast
)
Expand All @@ -505,6 +518,7 @@ private static DataSourcePlan forUnnest(
null,
null,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber,
broadcast
);
Expand Down Expand Up @@ -537,6 +551,7 @@ private static DataSourcePlan forUnion(
@Nullable DimFilter filter,
@Nullable Set<String> filterFields,
final int maxWorkerCount,
final int targetPartitionsPerWorker,
final int minStageNumber,
final boolean broadcast
)
Expand All @@ -559,6 +574,7 @@ private static DataSourcePlan forUnion(
filter,
filterFields,
maxWorkerCount,
targetPartitionsPerWorker,
Math.max(minStageNumber, subqueryDefBuilder.getNextStageNumber()),
broadcast
);
Expand Down Expand Up @@ -590,6 +606,7 @@ private static DataSourcePlan forBroadcastHashJoin(
@Nullable final DimFilter filter,
@Nullable final Set<String> filterFields,
final int maxWorkerCount,
final int targetPartitionsPerWorker,
final int minStageNumber,
final boolean broadcast
)
Expand All @@ -606,6 +623,7 @@ private static DataSourcePlan forBroadcastHashJoin(
filter,
filter == null ? null : DimFilterUtils.onlyBaseFields(filterFields, analysis),
maxWorkerCount,
targetPartitionsPerWorker,
Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
broadcast
);
Expand All @@ -626,6 +644,7 @@ private static DataSourcePlan forBroadcastHashJoin(
null, // Don't push down query filters for right-hand side: needs some work to ensure it works properly.
null,
maxWorkerCount,
targetPartitionsPerWorker,
Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
true // Always broadcast right-hand side of the join.
);
Expand Down Expand Up @@ -660,6 +679,7 @@ private static DataSourcePlan forSortMergeJoin(
final JoinDataSource dataSource,
final QuerySegmentSpec querySegmentSpec,
final int maxWorkerCount,
final int targetPartitionsPerWorker,
final int minStageNumber,
final boolean broadcast
)
Expand All @@ -682,6 +702,7 @@ private static DataSourcePlan forSortMergeJoin(
queryId,
(QueryDataSource) dataSource.getLeft(),
maxWorkerCount,
targetPartitionsPerWorker,
Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
false,
null
Expand All @@ -696,6 +717,7 @@ private static DataSourcePlan forSortMergeJoin(
queryId,
(QueryDataSource) dataSource.getRight(),
maxWorkerCount,
targetPartitionsPerWorker,
Math.max(minStageNumber, subQueryDefBuilder.getNextStageNumber()),
false,
null
Expand All @@ -707,8 +729,9 @@ private static DataSourcePlan forSortMergeJoin(
((StageInputSpec) Iterables.getOnlyElement(leftPlan.getInputSpecs())).getStageNumber()
);

final int hashPartitionCount = maxWorkerCount * targetPartitionsPerWorker;
final List<KeyColumn> leftPartitionKey = partitionKeys.get(0);
leftBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(leftPartitionKey, 0), maxWorkerCount));
leftBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(leftPartitionKey, 0), hashPartitionCount));
leftBuilder.signature(QueryKitUtils.sortableSignature(leftBuilder.getSignature(), leftPartitionKey));

// Build up the right stage.
Expand All @@ -717,7 +740,7 @@ private static DataSourcePlan forSortMergeJoin(
);

final List<KeyColumn> rightPartitionKey = partitionKeys.get(1);
rightBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(rightPartitionKey, 0), maxWorkerCount));
rightBuilder.shuffleSpec(new HashShuffleSpec(new ClusterBy(rightPartitionKey, 0), hashPartitionCount));
rightBuilder.signature(QueryKitUtils.sortableSignature(rightBuilder.getSignature(), rightPartitionKey));

// Compute join signature.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public QueryDefinition makeQueryDefinition(
QueryKit<Query<?>> toolKitForSubQueries,
ShuffleSpecFactory resultShuffleSpecFactory,
int maxWorkerCount,
int targetPartitionsPerWorker,
int minStageNumber
)
{
Expand All @@ -59,6 +60,7 @@ public QueryDefinition makeQueryDefinition(
this,
resultShuffleSpecFactory,
maxWorkerCount,
targetPartitionsPerWorker,
minStageNumber
);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ public interface QueryKit<QueryType extends Query<?>>
* @param minStageNumber lowest stage number to use for any generated stages. Useful if the resulting
* {@link QueryDefinition} is going to be added to an existing
* {@link org.apache.druid.msq.kernel.QueryDefinitionBuilder}.
* @param targetPartitionsPerWorker preferred number of partitions per worker for subqueries
*/
QueryDefinition makeQueryDefinition(
String queryId,
QueryType query,
QueryKit<Query<?>> toolKitForSubQueries,
ShuffleSpecFactory resultShuffleSpecFactory,
int maxWorkerCount,
int targetPartitionsPerWorker,
int minStageNumber
);
}
Loading

0 comments on commit d3f86ba

Please sign in to comment.