Fix Concurrent Task Insertion in pendingCompletionTaskGroups (#16834)
Fix streaming task failures that may arise due to concurrent task insertion in pendingCompletionTaskGroups
hardikbajaj authored Aug 8, 2024
1 parent ceed4a0 commit 1cf3f4b
Showing 2 changed files with 227 additions and 22 deletions.
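The failure being fixed is a classic check-then-act race. Before this change, addDiscoveredTaskToPendingCompletionTaskGroups fetched the per-group list with computeIfAbsent and then scanned it and appended a new TaskGroup outside of any lock, so two threads discovering replica tasks with the same starting offsets could each miss the other's group and insert a duplicate, which later surfaced as streaming task failures. The fix moves the whole find-or-create inside ConcurrentHashMap.compute, which runs atomically per key. A simplified sketch of the two shapes (illustrative names only; a plain offset map stands in for the supervisor's TaskGroup):

// Simplified sketch, not the actual supervisor code.
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

class PendingGroupsSketch
{
  // groupId -> pending "task groups", each represented here only by its starting offsets
  private final ConcurrentHashMap<Integer, CopyOnWriteArrayList<Map<String, String>>> pending = new ConcurrentHashMap<>();

  // Pre-fix shape: the check and the add happen outside any lock, so two threads can both
  // miss the matching group and both append one for the same offsets.
  void addRacy(int groupId, Map<String, String> startingOffsets)
  {
    CopyOnWriteArrayList<Map<String, String>> groups =
        pending.computeIfAbsent(groupId, k -> new CopyOnWriteArrayList<>());
    for (Map<String, String> group : groups) {
      if (group.equals(startingOffsets)) {
        return; // found an existing group
      }
    }
    groups.add(startingOffsets); // check-then-act race: duplicates possible
  }

  // Post-fix shape: the whole find-or-create runs inside compute(), which is atomic per key,
  // so at most one group is ever created per distinct offset map.
  void addAtomic(int groupId, Map<String, String> startingOffsets)
  {
    pending.compute(groupId, (k, groups) -> {
      if (groups == null) {
        groups = new CopyOnWriteArrayList<>();
      }
      if (groups.stream().noneMatch(g -> g.equals(startingOffsets))) {
        groups.add(startingOffsets);
      }
      return groups;
    });
  }
}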
@@ -2482,43 +2482,67 @@ private void verifyAndMergeCheckpoints(final TaskGroup taskGroup)
);
}

-private void addDiscoveredTaskToPendingCompletionTaskGroups(
+@VisibleForTesting
+protected void addDiscoveredTaskToPendingCompletionTaskGroups(
int groupId,
String taskId,
Map<PartitionIdType, SequenceOffsetType> startingPartitions
)
{
-final CopyOnWriteArrayList<TaskGroup> taskGroupList = pendingCompletionTaskGroups.computeIfAbsent(
+final CopyOnWriteArrayList<TaskGroup> taskGroupList = pendingCompletionTaskGroups.compute(
groupId,
-k -> new CopyOnWriteArrayList<>()
+(k, val) -> {
+// Create new pending completion task groups inside compute() so that reads and writes on the list are locked.
+// Doing the update inside compute() ensures synchronisation across threads, so we get only one task group for all replica tasks.
+if (val == null) {
+val = new CopyOnWriteArrayList<>();
+}
+
+boolean isTaskGroupPresent = false;
+for (TaskGroup taskGroup : val) {
+if (taskGroup.startingSequences.equals(startingPartitions)) {
+isTaskGroupPresent = true;
+break;
+}
+}
+if (!isTaskGroupPresent) {
+log.info("Creating new pending completion task group [%s] for discovered task [%s].", groupId, taskId);
+
+// reading the minimumMessageTime & maximumMessageTime from the publishing task and setting it here is not necessary as this task cannot
+// change to a state where it will read any more events.
+// This is a discovered task, so it would not have been assigned closed partitions initially.
+TaskGroup newTaskGroup = new TaskGroup(
+groupId,
+ImmutableMap.copyOf(startingPartitions),
+null,
+Optional.absent(),
+Optional.absent(),
+null
+);
+
+newTaskGroup.tasks.put(taskId, new TaskData());
+newTaskGroup.completionTimeout = DateTimes.nowUtc().plus(ioConfig.getCompletionTimeout());
+
+val.add(newTaskGroup);
+}
+return val;
+}
);

for (TaskGroup taskGroup : taskGroupList) {
if (taskGroup.startingSequences.equals(startingPartitions)) {
if (taskGroup.tasks.putIfAbsent(taskId, new TaskData()) == null) {
-log.info("Added discovered task [%s] to existing pending task group [%s]", taskId, groupId);
+log.info("Added discovered task [%s] to existing pending completion task group [%s]. PendingCompletionTaskGroup: %s", taskId, groupId, taskGroup.taskIds());
}
return;
}
}
+}
+
-log.info("Creating new pending completion task group [%s] for discovered task [%s]", groupId, taskId);
-
-// reading the minimumMessageTime & maximumMessageTime from the publishing task and setting it here is not necessary as this task cannot
-// change to a state where it will read any more events.
-// This is a discovered task, so it would not have been assigned closed partitions initially.
-TaskGroup newTaskGroup = new TaskGroup(
-groupId,
-ImmutableMap.copyOf(startingPartitions),
-null,
-Optional.absent(),
-Optional.absent(),
-null
-);
-
-newTaskGroup.tasks.put(taskId, new TaskData());
-newTaskGroup.completionTimeout = DateTimes.nowUtc().plus(ioConfig.getCompletionTimeout());
-
-taskGroupList.add(newTaskGroup);
+@VisibleForTesting
+protected CopyOnWriteArrayList<TaskGroup> getPendingCompletionTaskGroups(int groupId)
+{
+return pendingCompletionTaskGroups.get(groupId);
}

// Sanity check to ensure that tasks have the same sequence name as their task group
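Why pushing the work into compute() is sufficient: ConcurrentHashMap documents that the remapping function for a given key is invoked atomically, so concurrent callers updating the same group id serialize, while different group ids can still proceed in parallel (which the tests below exercise). A tiny standalone demo of that per-key guarantee (my own example, not Druid code):

// Demo: at most one thread at a time runs the remapping function for a given key.
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

public class ComputeAtomicityDemo
{
  public static void main(String[] args) throws InterruptedException
  {
    ConcurrentHashMap<Integer, Integer> map = new ConcurrentHashMap<>();
    AtomicInteger inside = new AtomicInteger();    // threads currently inside compute() for key 0
    AtomicInteger maxInside = new AtomicInteger(); // observed maximum; stays at 1 if compute() is atomic per key

    ExecutorService pool = Executors.newFixedThreadPool(8);
    for (int i = 0; i < 1000; i++) {
      pool.execute(() -> map.compute(0, (k, v) -> {
        maxInside.accumulateAndGet(inside.incrementAndGet(), Math::max);
        int next = (v == null) ? 1 : v + 1;
        inside.decrementAndGet();
        return next;
      }));
    }
    pool.shutdown();
    pool.awaitTermination(30, TimeUnit.SECONDS);

    System.out.println("updates applied = " + map.get(0));                 // 1000: no lost updates
    System.out.println("max threads inside compute = " + maxInside.get()); // 1: per-key mutual exclusion
  }
}

Because the remapping function holds the lock for its key while it runs, the committed code does nothing inside compute() beyond scanning the small list of pending groups and possibly appending one.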
@@ -76,6 +76,7 @@
import org.apache.druid.java.util.common.ISE;
import org.apache.druid.java.util.common.Intervals;
import org.apache.druid.java.util.common.StringUtils;
import org.apache.druid.java.util.common.concurrent.Execs;
import org.apache.druid.java.util.common.granularity.Granularities;
import org.apache.druid.java.util.common.parsers.JSONPathSpec;
import org.apache.druid.java.util.metrics.DruidMonitorSchedulerConfig;
@@ -114,8 +115,13 @@
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
@@ -281,6 +287,181 @@ public void testRunningStreamGetSequenceNumberReturnsNull()
verifyAll();
}

@Test
public void testAddDiscoveredTaskToPendingCompletionTaskGroups() throws Exception
{
EasyMock.expect(spec.isSuspended()).andReturn(false).anyTimes();
EasyMock.expect(recordSupplier.getPartitionIds(STREAM)).andReturn(ImmutableSet.of(SHARD_ID)).anyTimes();
EasyMock.expect(taskStorage.getActiveTasksByDatasource(DATASOURCE)).andReturn(ImmutableList.of()).anyTimes();
EasyMock.expect(taskQueue.add(EasyMock.anyObject())).andReturn(true).anyTimes();

replayAll();
ExecutorService threadExecutor = Execs.multiThreaded(3, "my-thread-pool-%d");

SeekableStreamSupervisor supervisor = new TestSeekableStreamSupervisor();
Map<String, String> startingPartitions = new HashMap<>();
startingPartitions.put("partition", "offset");

// Test concurrent threads adding to same task group
Callable<Boolean> task1 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_1", startingPartitions);
return true;
};
Callable<Boolean> task2 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_2", startingPartitions);
return true;
};
Callable<Boolean> task3 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_3", startingPartitions);
return true;
};

// Create a list to hold the Callable tasks
List<Callable<Boolean>> tasks = new ArrayList<>();
tasks.add(task1);
tasks.add(task2);
tasks.add(task3);
List<Future<Boolean>> futures = threadExecutor.invokeAll(tasks);
// Wait for all tasks to complete
for (Future<Boolean> future : futures) {
try {
Boolean result = future.get();
Assert.assertTrue(result);
}
catch (ExecutionException e) {
Assert.fail();
}
}
CopyOnWriteArrayList<SeekableStreamSupervisor.TaskGroup> taskGroups = supervisor.getPendingCompletionTaskGroups(0);
Assert.assertEquals(1, taskGroups.size());
Assert.assertEquals(3, taskGroups.get(0).tasks.size());

// Test concurrent threads adding to different task groups
task1 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions);
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions);
return true;
};
task2 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_1", startingPartitions);
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_1", startingPartitions);
return true;
};
task3 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_2", startingPartitions);
return true;
};
Callable<Boolean> task4 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(2, "task_2", startingPartitions);
return true;
};
Callable<Boolean> task5 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_3", startingPartitions);
return true;
};
Callable<Boolean> task6 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(1, "task_1", startingPartitions);
return true;
};

tasks = new ArrayList<>();
tasks.add(task1);
tasks.add(task2);
tasks.add(task3);
tasks.add(task4);
tasks.add(task5);
tasks.add(task6);
futures = threadExecutor.invokeAll(tasks);
for (Future<Boolean> future : futures) {
try {
Boolean result = future.get();
Assert.assertTrue(result);
}
catch (ExecutionException e) {
Assert.fail();
}
}

taskGroups = supervisor.getPendingCompletionTaskGroups(1);
Assert.assertEquals(1, taskGroups.size());
Assert.assertEquals(3, taskGroups.get(0).tasks.size());

taskGroups = supervisor.getPendingCompletionTaskGroups(2);
Assert.assertEquals(1, taskGroups.size());
Assert.assertEquals(2, taskGroups.get(0).tasks.size());
}

@Test
public void testAddDiscoveredTaskToPendingCompletionMultipleTaskGroups() throws Exception
{
EasyMock.expect(spec.isSuspended()).andReturn(false).anyTimes();
EasyMock.expect(recordSupplier.getPartitionIds(STREAM)).andReturn(ImmutableSet.of(SHARD_ID)).anyTimes();
EasyMock.expect(taskStorage.getActiveTasksByDatasource(DATASOURCE)).andReturn(ImmutableList.of()).anyTimes();
EasyMock.expect(taskQueue.add(EasyMock.anyObject())).andReturn(true).anyTimes();

replayAll();

// Test adding tasks to the same task group id with different partition offsets.
SeekableStreamSupervisor supervisor = new TestSeekableStreamSupervisor();
ExecutorService threadExecutor = Execs.multiThreaded(3, "my-thread-pool-%d");
Map<String, String> startingPartiions = new HashMap<>();
startingPartiions.put("partition", "offset");

Map<String, String> startingPartiions1 = new HashMap<>();
startingPartiions1.put("partition", "offset1");

Callable<Boolean> task1 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_1", startingPartiions);
return true;
};
Callable<Boolean> task2 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_2", startingPartiions);
return true;
};
Callable<Boolean> task3 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_3", startingPartiions);
return true;
};
Callable<Boolean> task4 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_7", startingPartiions1);
return true;
};
Callable<Boolean> task5 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_8", startingPartiions1);
return true;
};
Callable<Boolean> task6 = () -> {
supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(0, "task_9", startingPartiions1);
return true;
};

List<Callable<Boolean>> tasks = new ArrayList<>();
tasks.add(task1);
tasks.add(task2);
tasks.add(task3);
tasks.add(task4);
tasks.add(task5);
tasks.add(task6);

List<Future<Boolean>> futures = threadExecutor.invokeAll(tasks);

for (Future<Boolean> future : futures) {
try {
Boolean result = future.get();
Assert.assertTrue(result);
}
catch (ExecutionException e) {
Assert.fail();
}
}

CopyOnWriteArrayList<SeekableStreamSupervisor.TaskGroup> taskGroups = supervisor.getPendingCompletionTaskGroups(0);

Assert.assertEquals(2, taskGroups.size());
Assert.assertEquals(3, taskGroups.get(0).tasks.size());
Assert.assertEquals(3, taskGroups.get(1).tasks.size());
}
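One optional hardening of these tests, sketched below and not part of the commit: invokeAll() submits the callables but does not force them to enter the method at the same instant, so a CountDownLatch start gate (the class is already imported above) can make the concurrent calls overlap more reliably. The fragment reuses supervisor and threadExecutor as declared in the tests above; the group id and task ids are illustrative.

// Illustrative sketch only: hold every worker at a start gate, then release them together.
CountDownLatch startGate = new CountDownLatch(1);
Map<String, String> gateOffsets = new HashMap<>();
gateOffsets.put("partition", "offset");

List<Future<Boolean>> gatedFutures = new ArrayList<>();
for (String taskId : ImmutableList.of("task_1", "task_2", "task_3")) {
  gatedFutures.add(threadExecutor.submit(() -> {
    startGate.await(); // every worker blocks here until the gate opens
    supervisor.addDiscoveredTaskToPendingCompletionTaskGroups(5, taskId, gateOffsets);
    return true;
  }));
}
startGate.countDown(); // open the gate: the calls now race into the method together

for (Future<Boolean> future : gatedFutures) {
  Assert.assertTrue(future.get());
}
// Despite the simultaneous calls, exactly one pending completion task group exists for the group.
Assert.assertEquals(1, supervisor.getPendingCompletionTaskGroups(5).size());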

@Test
public void testConnectingToStreamFail()
{
