Skip to content

Commit

Permalink
Compute broadcast-join segmentMapFn only once per worker. (apache#15007)
Browse files Browse the repository at this point in the history
This patch introduces "processor managers" to processor factories, as a replacement for the sequence of processors. Processor managers can use the results of earlier processors to influence the creation of later processors, which provides us with the building block we need to ensure that broadcast join data is only read once.

In particular, when broadcast join is happening, the BaseFrameProcessorFactory now uses a ChainedProcessorManager to first run BroadcastJoinSegmentMapFnProcessor (in a single thread), and then run all of the regular processors (possibly multithreaded).
  • Loading branch information
gianm authored and ektravel committed Oct 16, 2023
1 parent b317b0b commit fd737b9
Show file tree
Hide file tree
Showing 43 changed files with 2,086 additions and 779 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ public void resultsComplete(
try {
convertedResultObject = context.jsonMapper().convertValue(
resultObject,
queryKernel.getStageDefinition(stageId).getProcessorFactory().getAccumulatedResultTypeReference()
queryKernel.getStageDefinition(stageId).getProcessorFactory().getResultTypeReference()
);
}
catch (IllegalArgumentException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
import org.apache.druid.frame.processor.PartitionedOutputChannel;
import org.apache.druid.frame.processor.SuperSorter;
import org.apache.druid.frame.processor.SuperSorterProgressTracker;
import org.apache.druid.frame.processor.manager.ProcessorManager;
import org.apache.druid.frame.processor.manager.ProcessorManagers;
import org.apache.druid.frame.util.DurableStorageUtils;
import org.apache.druid.frame.write.FrameWriters;
import org.apache.druid.indexer.TaskStatus;
Expand All @@ -71,8 +73,6 @@
import org.apache.druid.java.util.common.RE;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.UOE;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.java.util.common.guava.Sequences;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.msq.counters.CounterNames;
Expand Down Expand Up @@ -1100,11 +1100,12 @@ private void makeInputSliceReader()
.put(ExternalInputSlice.class, new ExternalInputSliceReader(frameContext.tempDir()))
.put(InlineInputSlice.class, new InlineInputSliceReader(frameContext.segmentWrangler()))
.put(LookupInputSlice.class, new LookupInputSliceReader(frameContext.segmentWrangler()))
.put(SegmentsInputSlice.class,
new SegmentsInputSliceReader(
frameContext.dataSegmentProvider(),
MultiStageQueryContext.isReindex(QueryContext.of(task().getContext()))
)
.put(
SegmentsInputSlice.class,
new SegmentsInputSliceReader(
frameContext.dataSegmentProvider(),
MultiStageQueryContext.isReindex(QueryContext.of(task().getContext()))
)
)
.build()
);
Expand Down Expand Up @@ -1152,7 +1153,16 @@ private void makeShuffleOutputChannelFactory(boolean isFinalStage)
);
}

private <FactoryType extends FrameProcessorFactory<I, WorkerClass, T, R>, I, WorkerClass extends FrameProcessor<T>, T, R> void makeAndRunWorkProcessors()
/**
* Use {@link FrameProcessorFactory#makeProcessors} to create {@link ProcessorsAndChannels}. Executes the
* processors using {@link #exec} and sets the output channels in {@link #workResultAndOutputChannels}.
*
* @param <FactoryType> type of {@link StageDefinition#getProcessorFactory()}
* @param <ProcessorReturnType> return type of {@link FrameProcessor} created by the manager
* @param <ManagerReturnType> result type of {@link ProcessorManager#result()}
* @param <ExtraInfoType> type of {@link WorkOrder#getExtraInfo()}
*/
private <FactoryType extends FrameProcessorFactory<ProcessorReturnType, ManagerReturnType, ExtraInfoType>, ProcessorReturnType, ManagerReturnType, ExtraInfoType> void makeAndRunWorkProcessors()
throws IOException
{
if (workResultAndOutputChannels != null) {
Expand All @@ -1163,21 +1173,21 @@ private <FactoryType extends FrameProcessorFactory<I, WorkerClass, T, R>, I, Wor
final FactoryType processorFactory = (FactoryType) kernel.getStageDefinition().getProcessorFactory();

@SuppressWarnings("unchecked")
final ProcessorsAndChannels<WorkerClass, T> processors =
final ProcessorsAndChannels<ProcessorReturnType, ManagerReturnType> processors =
processorFactory.makeProcessors(
kernel.getStageDefinition(),
kernel.getWorkOrder().getWorkerNumber(),
kernel.getWorkOrder().getInputs(),
inputSliceReader,
(I) kernel.getWorkOrder().getExtraInfo(),
(ExtraInfoType) kernel.getWorkOrder().getExtraInfo(),
workOutputChannelFactory,
frameContext,
parallelism,
counterTracker,
e -> warningPublisher.publishException(kernel.getStageDefinition().getStageNumber(), e)
);

final Sequence<WorkerClass> processorSequence = processors.processors();
final ProcessorManager<ProcessorReturnType, ManagerReturnType> processorManager = processors.getProcessorManager();

final int maxOutstandingProcessors;

Expand All @@ -1190,10 +1200,8 @@ private <FactoryType extends FrameProcessorFactory<I, WorkerClass, T, R>, I, Wor
Math.max(1, Math.min(parallelism, processors.getOutputChannels().getAllChannels().size()));
}

final ListenableFuture<R> workResultFuture = exec.runAllFully(
processorSequence,
processorFactory.newAccumulatedResult(),
processorFactory::accumulateResult,
final ListenableFuture<ManagerReturnType> workResultFuture = exec.runAllFully(
processorManager,
maxOutstandingProcessors,
processorBouncer,
cancellationId
Expand Down Expand Up @@ -1716,11 +1724,13 @@ private ResultAndChannels<?> gatherResultKeyStatistics(final OutputChannels chan

final ListenableFuture<ClusterByStatisticsCollector> clusterByStatisticsCollectorFuture =
exec.runAllFully(
Sequences.simple(processors),
stageDefinition.createResultKeyStatisticsCollector(
frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes()
),
ClusterByStatisticsCollector::addAll,
ProcessorManagers.of(processors)
.withAccumulation(
stageDefinition.createResultKeyStatisticsCollector(
frameContext.memoryParameters().getPartitionStatisticsMaxRetainedBytes()
),
ClusterByStatisticsCollector::addAll
),
// Run all processors simultaneously. They are lightweight and this keeps things moving.
processors.size(),
Bouncer.unlimited(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.druid.msq.input.InputSpecs;
import org.apache.druid.msq.kernel.QueryDefinition;
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.querykit.BroadcastJoinSegmentMapFnProcessor;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollectorImpl;
import org.apache.druid.query.lookup.LookupExtractor;
import org.apache.druid.query.lookup.LookupExtractorFactoryContainer;
Expand Down Expand Up @@ -130,11 +131,11 @@ public class WorkerMemoryParameters
private static final long SMALL_WORKER_CAPACITY_THRESHOLD_BYTES = 256_000_000;

/**
* Fraction of free memory per bundle that can be used by {@link org.apache.druid.msq.querykit.BroadcastJoinHelper}
* to store broadcast data on-heap. This is used to limit the total size of input frames, which we expect to
* expand on-heap. Expansion can potentially be somewhat over 2x: for example, strings are UTF-8 in frames, but are
* UTF-16 on-heap, which is a 2x expansion, and object and index overhead must be considered on top of that. So
* we use a value somewhat lower than 0.5.
* Fraction of free memory per bundle that can be used by {@link BroadcastJoinSegmentMapFnProcessor} to store broadcast
* data on-heap. This is used to limit the total size of input frames, which we expect to expand on-heap. Expansion
* can potentially be somewhat over 2x: for example, strings are UTF-8 in frames, but are UTF-16 on-heap, which is
* a 2x expansion, and object and index overhead must be considered on top of that. So we use a value somewhat
* lower than 0.5.
*/
static final double BROADCAST_JOIN_MEMORY_FRACTION = 0.3;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.google.common.collect.Iterables;
import org.apache.druid.frame.processor.OutputChannelFactory;
import org.apache.druid.frame.processor.OutputChannels;
import org.apache.druid.frame.processor.manager.ProcessorManagers;
import org.apache.druid.indexer.partitions.DynamicPartitionsSpec;
import org.apache.druid.indexer.partitions.PartitionsSpec;
import org.apache.druid.java.util.common.Pair;
Expand Down Expand Up @@ -75,7 +76,7 @@

@JsonTypeName("segmentGenerator")
public class SegmentGeneratorFrameProcessorFactory
implements FrameProcessorFactory<List<SegmentIdWithShardSpec>, SegmentGeneratorFrameProcessor, DataSegment, Set<DataSegment>>
implements FrameProcessorFactory<DataSegment, Set<DataSegment>, List<SegmentIdWithShardSpec>>
{
private final DataSchema dataSchema;
private final ColumnMappings columnMappings;
Expand Down Expand Up @@ -112,7 +113,7 @@ public MSQTuningConfig getTuningConfig()
}

@Override
public ProcessorsAndChannels<SegmentGeneratorFrameProcessor, DataSegment> makeProcessors(
public ProcessorsAndChannels<DataSegment, Set<DataSegment>> makeProcessors(
StageDefinition stageDefinition,
int workerNumber,
List<InputSlice> inputSlices,
Expand Down Expand Up @@ -151,7 +152,8 @@ public Pair<Integer, ReadableInput> apply(ReadableInput readableInput)
}
));
final SegmentGenerationProgressCounter segmentGenerationProgressCounter = counters.segmentGenerationProgress();
final SegmentGeneratorMetricsWrapper segmentGeneratorMetricsWrapper = new SegmentGeneratorMetricsWrapper(segmentGenerationProgressCounter);
final SegmentGeneratorMetricsWrapper segmentGeneratorMetricsWrapper =
new SegmentGeneratorMetricsWrapper(segmentGenerationProgressCounter);

final Sequence<SegmentGeneratorFrameProcessor> workers = inputSequence.map(
readableInputPair -> {
Expand Down Expand Up @@ -196,32 +198,28 @@ public Pair<Integer, ReadableInput> apply(ReadableInput readableInput)
}
);

return new ProcessorsAndChannels<>(workers, OutputChannels.none());
return new ProcessorsAndChannels<>(
ProcessorManagers.of(workers)
.withAccumulation(
new HashSet<>(),
(acc, segment) -> {
if (segment != null) {
acc.add(segment);
}

return acc;
}
),
OutputChannels.none()
);
}

@Override
public TypeReference<Set<DataSegment>> getAccumulatedResultTypeReference()
public TypeReference<Set<DataSegment>> getResultTypeReference()
{
return new TypeReference<Set<DataSegment>>() {};
}

@Override
public Set<DataSegment> newAccumulatedResult()
{
return new HashSet<>();
}

@Nullable
@Override
public Set<DataSegment> accumulateResult(Set<DataSegment> accumulated, DataSegment current)
{
if (current != null) {
accumulated.add(current);
}

return accumulated;
}

@Nullable
@Override
public Set<DataSegment> mergeAccumulatedResult(Set<DataSegment> accumulated, Set<DataSegment> otherAccumulated)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

import com.fasterxml.jackson.annotation.JsonTypeInfo;
import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.OutputChannelFactory;
import org.apache.druid.frame.processor.manager.ProcessorManager;
import org.apache.druid.msq.counters.CounterTracker;
import org.apache.druid.msq.input.InputSlice;
import org.apache.druid.msq.input.InputSliceReader;
Expand All @@ -36,18 +36,17 @@
* Property of {@link StageDefinition} that describes its computation logic.
*
* Workers call {@link #makeProcessors} to generate the processors that perform computations within that worker's
* {@link org.apache.druid.frame.processor.FrameProcessorExecutor}. Additionally, provides methods for accumulating
* the results of the processors: {@link #newAccumulatedResult()}, {@link #accumulateResult}, and
* {@link #mergeAccumulatedResult}.
* {@link org.apache.druid.frame.processor.FrameProcessorExecutor}. Additionally, provides
* {@link #mergeAccumulatedResult(Object, Object)} for merging results from {@link ProcessorManager#result()}.
*/
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
public interface FrameProcessorFactory<ExtraInfoType, ProcessorType extends FrameProcessor<T>, T, R>
public interface FrameProcessorFactory<T, R, ExtraInfoType>
{
/**
* Create processors for a particular worker in a particular stage. The processors will be run on a thread pool,
* with at most "maxOutstandingProcessors" number of processors outstanding at once.
*
* The Sequence returned by {@link ProcessorsAndChannels#processors()} is passed directly to
* The {@link ProcessorManager} returned by {@link ProcessorsAndChannels#getProcessorManager()} is passed directly to
* {@link org.apache.druid.frame.processor.FrameProcessorExecutor#runAllFully}.
*
* @param stageDefinition stage definition
Expand All @@ -65,7 +64,7 @@ public interface FrameProcessorFactory<ExtraInfoType, ProcessorType extends Fram
*
* @return a processor manager, which supplies the processors to run; and a list of output channels.
*/
ProcessorsAndChannels<ProcessorType, T> makeProcessors(
ProcessorsAndChannels<T, R> makeProcessors(
StageDefinition stageDefinition,
int workerNumber,
List<InputSlice> inputSlices,
Expand All @@ -78,18 +77,8 @@ ProcessorsAndChannels<ProcessorType, T> makeProcessors(
Consumer<Throwable> warningPublisher
) throws IOException;

TypeReference<R> getAccumulatedResultTypeReference();

/**
* Produces a "blank slate" result.
*/
R newAccumulatedResult();

/**
* Accumulates an additional result. May modify the left-hand side {@code accumulated}. Does not modify the
* right-hand side {@code current}.
*/
R accumulateResult(R accumulated, T current);
@Nullable
TypeReference<R> getResultTypeReference();

/**
* Merges two accumulated results. May modify the left-hand side {@code accumulated}. Does not modify the right-hand
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,34 @@

package org.apache.druid.msq.kernel;

import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.frame.processor.OutputChannels;
import org.apache.druid.java.util.common.guava.Sequence;
import org.apache.druid.frame.processor.manager.ProcessorManager;

/**
* Returned from {@link FrameProcessorFactory#makeProcessors}.
*
* Includes a processor manager and a list of output channels.
*
* @param <T> return type of {@link org.apache.druid.frame.processor.FrameProcessor} from {@link #getProcessorManager()}
* @param <R> result type of {@link ProcessorManager#result()}
*/
public class ProcessorsAndChannels<ProcessorClass extends FrameProcessor<T>, T>
public class ProcessorsAndChannels<T, R>
{
private final Sequence<ProcessorClass> workers;
private final ProcessorManager<T, R> processorManager;
private final OutputChannels outputChannels;

public ProcessorsAndChannels(
final Sequence<ProcessorClass> workers,
final ProcessorManager<T, R> processorManager,
final OutputChannels outputChannels
)
{
this.workers = workers;
this.processorManager = processorManager;
this.outputChannels = outputChannels;
}

public Sequence<ProcessorClass> processors()
public ProcessorManager<T, R> getProcessorManager()
{
return workers;
return processorManager;
}

public OutputChannels getOutputChannels()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
package org.apache.druid.msq.querykit;

import com.fasterxml.jackson.core.type.TypeReference;
import org.apache.druid.frame.processor.FrameProcessor;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.msq.kernel.ExtraInfoHolder;
import org.apache.druid.msq.kernel.FrameProcessorFactory;
Expand All @@ -30,30 +29,17 @@

/**
* Basic abstract {@link FrameProcessorFactory} that yields workers that do not require extra info and that
* always return Longs. This base class isn't used for every worker factory, but it is used for many of them.
* ignore the return values of their processors. This base class isn't used for every worker factory, but it is used
* for many of them.
*/
public abstract class BaseFrameProcessorFactory
implements FrameProcessorFactory<Object, FrameProcessor<Long>, Long, Long>
public abstract class BaseFrameProcessorFactory implements FrameProcessorFactory<Object, Long, Object>
{
@Override
public TypeReference<Long> getAccumulatedResultTypeReference()
public TypeReference<Long> getResultTypeReference()
{
return new TypeReference<Long>() {};
}

@Override
public Long newAccumulatedResult()
{
return 0L;
}

@Nullable
@Override
public Long accumulateResult(Long accumulated, Long current)
{
return accumulated + current;
}

@Override
public Long mergeAccumulatedResult(Long accumulated, Long otherAccumulated)
{
Expand Down
Loading

0 comments on commit fd737b9

Please sign in to comment.