Skip to content

Commit

Permalink
MSQ: Use task context flag useConcurrentLocks to determine task lock …
Browse files Browse the repository at this point in the history
…type (apache#17193)
  • Loading branch information
kfaraz authored Sep 30, 2024
1 parent 15987f5 commit 28fead5
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 109 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.inject.Injector;
import org.apache.druid.indexing.common.TaskLockType;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.java.util.common.io.Closer;
import org.apache.druid.java.util.emitter.service.ServiceEmitter;
Expand Down Expand Up @@ -83,6 +84,11 @@ public interface ControllerContext
*/
TaskActionClient taskActionClient();

/**
* Task lock type that the controller uses when generating segment IDs and when publishing segments.
*
* NOTE(review): the task-side method this is expected to delegate to is annotated {@code @Nullable}
* (no lock type when the query does not ingest into a datasource) -- confirm whether callers of this
* method must handle a null return.
*/
TaskLockType taskLockType();

/**
* Provides services about workers: starting, canceling, obtaining status.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,7 @@ private List<SegmentIdWithShardSpec> generateSegmentIdsWithShardSpecs(
destination,
partitionBoundaries,
keyReader,
MultiStageQueryContext.validateAndGetTaskLockType(QueryContext.of(querySpec.getQuery().getContext()), false),
context.taskLockType(),
isStageOutputEmpty
);
}
Expand Down Expand Up @@ -1335,10 +1335,7 @@ private void publishAllSegments(
(DataSourceMSQDestination) querySpec.getDestination();
final Set<DataSegment> segmentsWithTombstones = new HashSet<>(segments);
int numTombstones = 0;
final TaskLockType taskLockType = MultiStageQueryContext.validateAndGetTaskLockType(
QueryContext.of(querySpec.getQuery().getContext()),
destination.isReplaceTimeChunks()
);
final TaskLockType taskLockType = context.taskLockType();

if (destination.isReplaceTimeChunks()) {
final List<Interval> intervalsToDrop = findIntervalsToDrop(Preconditions.checkNotNull(segments, "segments"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.google.inject.Injector;
import com.google.inject.Key;
import org.apache.druid.guice.annotations.Self;
import org.apache.druid.indexing.common.TaskLockType;
import org.apache.druid.indexing.common.TaskToolbox;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.common.task.IndexTaskUtils;
Expand Down Expand Up @@ -168,6 +169,12 @@ public TaskActionClient taskActionClient()
return toolbox.getTaskActionClient();
}

/**
* Delegates to {@code task.getTaskLockType()}, so the lock type is decided by the task rather than
* recomputed here.
*/
@Override
public TaskLockType taskLockType()
{
return task.getTaskLockType();
}

@Override
public WorkerClient newWorkerClient()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.apache.druid.msq.indexing.destination.TaskReportMSQDestination;
import org.apache.druid.msq.util.MultiStageQueryContext;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
import org.apache.druid.rpc.ServiceClientFactory;
import org.apache.druid.rpc.StandardRetryPolicy;
import org.apache.druid.rpc.indexing.OverlordClient;
Expand Down Expand Up @@ -234,8 +235,7 @@ public boolean isReady(TaskActionClient taskActionClient) throws Exception
{
// If we're in replace mode, acquire locks for all intervals before declaring the task ready.
if (isIngestion(querySpec) && ((DataSourceMSQDestination) querySpec.getDestination()).isReplaceTimeChunks()) {
final TaskLockType taskLockType =
MultiStageQueryContext.validateAndGetTaskLockType(QueryContext.of(querySpec.getQuery().getContext()), true);
final TaskLockType taskLockType = getTaskLockType();
final List<Interval> intervals =
((DataSourceMSQDestination) querySpec.getDestination()).getReplaceTimeChunks();
log.debug(
Expand Down Expand Up @@ -306,6 +306,26 @@ public int getPriority()
return getContextValue(Tasks.PRIORITY_KEY, Tasks.DEFAULT_BATCH_INDEX_TASK_PRIORITY);
}

/**
 * Determines the lock type this task should use.
 *
 * <p>For ingestion queries, the task context is combined with the query context (query context
 * values override task context values) and the result is validated against whether the destination
 * replaces time chunks.
 *
 * @return the validated lock type for an ingestion query, or {@code null} when the query does not
 *         ingest into a datasource and therefore needs no locks
 */
@Nullable
public TaskLockType getTaskLockType()
{
  // Locks need to be acquired only if data is being ingested into a DataSource
  if (!isIngestion(querySpec)) {
    return null;
  }

  // Use the task context and override with the query context
  final QueryContext effectiveContext = QueryContext.of(
      QueryContexts.override(getContext(), querySpec.getQuery().getContext())
  );
  final boolean replaceTimeChunks =
      ((DataSourceMSQDestination) querySpec.getDestination()).isReplaceTimeChunks();

  return MultiStageQueryContext.validateAndGetTaskLockType(effectiveContext, replaceTimeChunks);
}

private static String getDataSourceForTaskMetadata(final MSQSpec querySpec)
{
final MSQDestination destination = querySpec.getDestination();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.druid.indexing.common.actions.TaskAction;
import org.apache.druid.indexing.common.actions.TaskActionClient;
import org.apache.druid.indexing.common.actions.TimeChunkLockTryAcquireAction;
import org.apache.druid.indexing.common.task.Tasks;
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.granularity.Granularities;
Expand All @@ -46,85 +47,56 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class MSQControllerTaskTest
{
private final List<Interval> INTERVALS =
Collections.singletonList(Intervals.of(
"2011-04-01T00:00:00.000Z/2011-04-03T00:00:00.000Z"));

private final MSQSpec MSQ_SPEC = MSQSpec
.builder()
.destination(new DataSourceMSQDestination(
"target",
Granularities.DAY,
null,
INTERVALS,
null,
null
))
.query(new Druids.ScanQueryBuilder()
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.intervals(new MultipleIntervalSegmentSpec(INTERVALS))
.dataSource("target")
.build()
)
.columnMappings(new ColumnMappings(ImmutableList.of(new ColumnMapping("a0", "cnt"))))
.tuningConfig(MSQTuningConfig.defaultConfig())
.build();
private static final List<Interval> INTERVALS = Collections.singletonList(
Intervals.of("2011-04-01/2011-04-03")
);

/**
 * Creates a spec builder shared by the tests: a datasource destination "target" with DAY segment
 * granularity over {@link #INTERVALS}, a compacted-list scan query on "target" over the same
 * intervals, a single column mapping from "a0" to "cnt", and the default tuning config.
 *
 * <p>The builder (rather than a built spec) is returned so that individual tests can override the
 * destination or query before building.
 */
private static MSQSpec.Builder msqSpecBuilder()
{
  final DataSourceMSQDestination destination =
      new DataSourceMSQDestination("target", Granularities.DAY, null, INTERVALS, null, null);

  final ScanQuery scanQuery = new Druids.ScanQueryBuilder()
      .resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
      .intervals(new MultipleIntervalSegmentSpec(INTERVALS))
      .dataSource("target")
      .build();

  final ColumnMappings columnMappings =
      new ColumnMappings(ImmutableList.of(new ColumnMapping("a0", "cnt")));

  return MSQSpec.builder()
                .destination(destination)
                .query(scanQuery)
                .columnMappings(columnMappings)
                .tuningConfig(MSQTuningConfig.defaultConfig());
}

@Test
public void testGetInputSourceResources()
{
MSQControllerTask controllerTask = new MSQControllerTask(
null,
MSQ_SPEC,
null,
null,
null,
null,
null,
null
);
Assert.assertTrue(controllerTask.getInputSourceResources().isEmpty());
Assert.assertTrue(createControllerTask(msqSpecBuilder()).getInputSourceResources().isEmpty());
}

@Test
public void testGetDefaultLookupLoadingSpec()
{
MSQControllerTask controllerTask = new MSQControllerTask(
null,
MSQ_SPEC,
null,
null,
null,
null,
null,
null
);
MSQControllerTask controllerTask = createControllerTask(msqSpecBuilder());
Assert.assertEquals(LookupLoadingSpec.NONE, controllerTask.getLookupLoadingSpec());
}

@Test
public void testGetDefaultBroadcastDatasourceLoadingSpec()
{
MSQControllerTask controllerTask = new MSQControllerTask(
null,
MSQ_SPEC,
null,
null,
null,
null,
null,
null
);
MSQControllerTask controllerTask = createControllerTask(msqSpecBuilder());
Assert.assertEquals(BroadcastDatasourceLoadingSpec.NONE, controllerTask.getBroadcastDatasourceLoadingSpec());
}

@Test
public void testGetLookupLoadingSpecUsingLookupLoadingInfoInContext()
{
MSQSpec build = MSQSpec
MSQSpec.Builder builder = MSQSpec
.builder()
.query(new Druids.ScanQueryBuilder()
.intervals(new MultipleIntervalSegmentSpec(INTERVALS))
Expand All @@ -137,54 +109,83 @@ public void testGetLookupLoadingSpecUsingLookupLoadingInfoInContext()
.build()
)
.columnMappings(new ColumnMappings(Collections.emptyList()))
.tuningConfig(MSQTuningConfig.defaultConfig())
.build();
MSQControllerTask controllerTask = new MSQControllerTask(
null,
build,
null,
null,
null,
null,
null,
null
);
.tuningConfig(MSQTuningConfig.defaultConfig());

// Va;idate that MSQ Controller task doesn't load any lookups even if context has lookup info populated.
Assert.assertEquals(LookupLoadingSpec.NONE, controllerTask.getLookupLoadingSpec());
// Validate that MSQ Controller task doesn't load any lookups even if context has lookup info populated.
Assert.assertEquals(LookupLoadingSpec.NONE, createControllerTask(builder).getLookupLoadingSpec());
}

@Test
public void testGetTaskAllocatorId()
{
final String taskId = "taskId";
MSQControllerTask controllerTask = new MSQControllerTask(
taskId,
MSQ_SPEC,
MSQControllerTask controllerTask = createControllerTask(msqSpecBuilder());
Assert.assertEquals(controllerTask.getId(), controllerTask.getTaskAllocatorId());
}

@Test
public void testGetTaskLockType()
{
final DataSourceMSQDestination appendDestination
= new DataSourceMSQDestination("target", Granularities.DAY, null, null, null, null);
Assert.assertEquals(
TaskLockType.SHARED,
createControllerTask(msqSpecBuilder().destination(appendDestination)).getTaskLockType()
);

final DataSourceMSQDestination replaceDestination
= new DataSourceMSQDestination("target", Granularities.DAY, null, INTERVALS, null, null);
Assert.assertEquals(
TaskLockType.EXCLUSIVE,
createControllerTask(msqSpecBuilder().destination(replaceDestination)).getTaskLockType()
);

// With 'useConcurrentLocks' in task context
final Map<String, Object> taskContext = Collections.singletonMap(Tasks.USE_CONCURRENT_LOCKS, true);
final MSQControllerTask appendTaskWithContext = new MSQControllerTask(
null,
msqSpecBuilder().destination(appendDestination).build(),
null,
null,
null,
null,
null,
null
taskContext
);
Assert.assertEquals(taskId, controllerTask.getTaskAllocatorId());
}
Assert.assertEquals(TaskLockType.APPEND, appendTaskWithContext.getTaskLockType());

@Test
public void testIsReady() throws Exception
{
final String taskId = "taskId";
MSQControllerTask controllerTask = new MSQControllerTask(
taskId,
MSQ_SPEC,
final MSQControllerTask replaceTaskWithContext = new MSQControllerTask(
null,
msqSpecBuilder().destination(replaceDestination).build(),
null,
null,
null,
null,
null,
null
taskContext
);
Assert.assertEquals(TaskLockType.REPLACE, replaceTaskWithContext.getTaskLockType());

// With 'useConcurrentLocks' in query context
final Map<String, Object> queryContext = Collections.singletonMap(Tasks.USE_CONCURRENT_LOCKS, true);
final ScanQuery query = new Druids.ScanQueryBuilder()
.resultFormat(ScanQuery.ResultFormat.RESULT_FORMAT_COMPACTED_LIST)
.intervals(new MultipleIntervalSegmentSpec(INTERVALS))
.dataSource("target")
.context(queryContext)
.build();
Assert.assertEquals(
TaskLockType.APPEND,
createControllerTask(msqSpecBuilder().query(query).destination(appendDestination)).getTaskLockType()
);
Assert.assertEquals(
TaskLockType.REPLACE,
createControllerTask(msqSpecBuilder().query(query).destination(replaceDestination)).getTaskLockType()
);
}

@Test
public void testIsReady() throws Exception
{
TestTaskActionClient taskActionClient = new TestTaskActionClient(
new TimeChunkLock(
TaskLockType.REPLACE,
Expand All @@ -195,24 +196,14 @@ public void testIsReady() throws Exception
0
)
);
Assert.assertTrue(controllerTask.isReady(taskActionClient));
Assert.assertTrue(createControllerTask(msqSpecBuilder()).isReady(taskActionClient));
}

@Test
public void testIsReadyWithRevokedLock()
{
final String taskId = "taskId";
MSQControllerTask controllerTask = new MSQControllerTask(
taskId,
MSQ_SPEC,
null,
null,
null,
null,
null,
null
);
TestTaskActionClient taskActionClient = new TestTaskActionClient(
MSQControllerTask controllerTask = createControllerTask(msqSpecBuilder());
TaskActionClient taskActionClient = new TestTaskActionClient(
new TimeChunkLock(
TaskLockType.REPLACE,
"groupId",
Expand All @@ -225,10 +216,17 @@ public void testIsReadyWithRevokedLock()
);
DruidException exception = Assert.assertThrows(
DruidException.class,
() -> controllerTask.isReady(taskActionClient));
() -> controllerTask.isReady(taskActionClient)
);
Assert.assertEquals(
"Lock of type[REPLACE] for interval[2011-04-01T00:00:00.000Z/2011-04-03T00:00:00.000Z] was revoked",
exception.getMessage());
exception.getMessage()
);
}

/**
 * Builds the given spec and wraps it in a controller task with the fixed id "controller_1".
 * Every other constructor argument, including the trailing task context, is passed as null.
 */
private static MSQControllerTask createControllerTask(MSQSpec.Builder specBuilder)
{
  final MSQSpec spec = specBuilder.build();
  return new MSQControllerTask(
      "controller_1",
      spec,
      null,
      null,
      null,
      null,
      null,
      null,
      null
  );
}

private static class TestTaskActionClient implements TaskActionClient
Expand Down
Loading

0 comments on commit 28fead5

Please sign in to comment.