Skip to content

Commit

Permalink
Move some lifecycle management from doTask -> shutdown for the mm-les…
Browse files Browse the repository at this point in the history
…s task runner (#14895)

* save work

* Add synchronized

* Don't shutdown in run

* Adding unit tests

* Cleanup lifecycle

* Fix tests

* remove newline
  • Loading branch information
georgew5656 authored Aug 25, 2023
1 parent ad32f84 commit 95b0de6
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ protected synchronized TaskStatus run(Job job, long launchTimeout, long timeout)
}
catch (Exception e) {
log.info("Failed to run task: %s", taskId.getOriginalTaskId());
shutdown();
throw e;
}
finally {
Expand Down Expand Up @@ -168,10 +167,9 @@ protected synchronized TaskStatus join(long timeout) throws IllegalStateExceptio
finally {
try {
saveLogs();
shutdown();
}
catch (Exception e) {
log.warn(e, "Task [%s] cleanup failed", taskId);
log.warn(e, "Log processing failed for task [%s]", taskId);
}

stopTask();
Expand All @@ -188,7 +186,7 @@ protected synchronized TaskStatus join(long timeout) throws IllegalStateExceptio
*/
protected void shutdown()
{
if (State.PENDING.equals(state.get()) || State.RUNNING.equals(state.get())) {
if (State.PENDING.equals(state.get()) || State.RUNNING.equals(state.get()) || State.STOPPED.equals(state.get())) {
kubernetesClient.deletePeonJob(taskId);
}
}
Expand Down Expand Up @@ -223,7 +221,7 @@ protected State getState()
*/
protected TaskLocation getTaskLocation()
{
if (!State.RUNNING.equals(state.get())) {
if (State.PENDING.equals(state.get()) || State.NOT_STARTED.equals(state.get())) {
log.debug("Can't get task location for non-running job. [%s]", taskId.getOriginalTaskId());
return TaskLocation.unknown();
}
Expand Down Expand Up @@ -251,7 +249,6 @@ protected TaskLocation getTaskLocation()
Boolean.parseBoolean(pod.getMetadata().getAnnotations().getOrDefault(DruidK8sConstants.TLS_ENABLED, "false")),
pod.getMetadata() != null ? pod.getMetadata().getName() : ""
);
log.info("K8s task %s is running at location %s", taskId.getOriginalTaskId(), taskLocation);
}

return taskLocation;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,6 @@ protected TaskStatus doTask(Task task, boolean run)
KubernetesWorkItem workItem = tasks.get(task.getId());

if (workItem == null) {
throw new ISE("Task [%s] disappeared", task.getId());
}

if (workItem.isShutdownRequested()) {
throw new ISE("Task [%s] has been shut down", task.getId());
}

Expand Down Expand Up @@ -213,11 +209,6 @@ protected TaskStatus doTask(Task task, boolean run)
log.error(e, "Task [%s] execution caught an exception", task.getId());
throw new RuntimeException(e);
}
finally {
synchronized (tasks) {
tasks.remove(task.getId());
}
}
}

@VisibleForTesting
Expand Down Expand Up @@ -271,6 +262,10 @@ public void shutdown(String taskid, String reason)
return;
}

synchronized (tasks) {
tasks.remove(taskid);
}

workItem.shutdown();
}

Expand Down Expand Up @@ -440,6 +435,17 @@ public Collection<TaskRunnerWorkItem> getPendingTasks()
.collect(Collectors.toList());
}

/**
 * Looks up the location of the task registered under the given id.
 *
 * @param taskId id of the task, used as the key into {@code tasks}
 * @return the location reported by the task's work item, or {@link TaskLocation#unknown()}
 *         when no work item is registered under {@code taskId}
 */
@Override
public TaskLocation getTaskLocation(String taskId)
{
final KubernetesWorkItem workItem = tasks.get(taskId);
if (workItem == null) {
// No work item is tracked under this id, so there is no location to report.
return TaskLocation.unknown();
} else {
return workItem.getLocation();
}
}

@Nullable
@Override
public RunnerTaskState getRunnerTaskState(String taskId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,10 @@
import org.apache.druid.java.util.common.ISE;

import java.io.InputStream;
import java.util.concurrent.atomic.AtomicBoolean;

public class KubernetesWorkItem extends TaskRunnerWorkItem
{
private final Task task;

private final AtomicBoolean shutdownRequested = new AtomicBoolean(false);
private KubernetesPeonLifecycle kubernetesPeonLifecycle = null;

public KubernetesWorkItem(Task task, ListenableFuture<TaskStatus> statusFuture)
Expand All @@ -53,19 +50,13 @@ protected synchronized void setKubernetesPeonLifecycle(KubernetesPeonLifecycle k

protected synchronized void shutdown()
{
this.shutdownRequested.set(true);

if (this.kubernetesPeonLifecycle != null) {
this.kubernetesPeonLifecycle.startWatchingLogs();
this.kubernetesPeonLifecycle.shutdown();
}
}

protected boolean isShutdownRequested()
{
return shutdownRequested.get();
}

protected boolean isPending()
{
return RunnerTaskState.PENDING.equals(getRunnerTaskState());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,6 @@ protected synchronized TaskStatus join(long timeout)
EasyMock.anyLong(),
EasyMock.eq(TimeUnit.MILLISECONDS)
)).andReturn(null);
EasyMock.expect(kubernetesClient.deletePeonJob(
new K8sTaskId(ID)
)).andReturn(true);
Assert.assertEquals(KubernetesPeonLifecycle.State.NOT_STARTED, peonLifecycle.getState());
stateListener.stateChanged(KubernetesPeonLifecycle.State.PENDING, ID);
EasyMock.expectLastCall().once();
Expand Down Expand Up @@ -245,7 +242,6 @@ public void test_join_withoutJob_returnsFailedTaskStatus() throws IOException
EasyMock.expectLastCall().once();
logWatch.close();
EasyMock.expectLastCall();
EasyMock.expect(kubernetesClient.deletePeonJob(k8sTaskId)).andReturn(true);

replayAll();

Expand Down Expand Up @@ -298,7 +294,6 @@ public void test_join() throws IOException
EasyMock.expectLastCall().once();
logWatch.close();
EasyMock.expectLastCall();
EasyMock.expect(kubernetesClient.deletePeonJob(k8sTaskId)).andReturn(true);

Assert.assertEquals(KubernetesPeonLifecycle.State.NOT_STARTED, peonLifecycle.getState());

Expand Down Expand Up @@ -353,7 +348,6 @@ public void test_join_whenCalledMultipleTimes_raisesIllegalStateException() thro
EasyMock.expectLastCall().once();
logWatch.close();
EasyMock.expectLastCall();
EasyMock.expect(kubernetesClient.deletePeonJob(k8sTaskId)).andReturn(true);

Assert.assertEquals(KubernetesPeonLifecycle.State.NOT_STARTED, peonLifecycle.getState());

Expand Down Expand Up @@ -408,7 +402,6 @@ public void test_join_withoutTaskStatus_returnsFailedTaskStatus() throws IOExcep
EasyMock.expectLastCall().once();
logWatch.close();
EasyMock.expectLastCall();
EasyMock.expect(kubernetesClient.deletePeonJob(k8sTaskId)).andReturn(true);

Assert.assertEquals(KubernetesPeonLifecycle.State.NOT_STARTED, peonLifecycle.getState());

Expand Down Expand Up @@ -459,7 +452,6 @@ public void test_join_whenIOExceptionThrownWhileStreamingTaskStatus_returnsFaile
EasyMock.expectLastCall().once();
logWatch.close();
EasyMock.expectLastCall();
EasyMock.expect(kubernetesClient.deletePeonJob(k8sTaskId)).andReturn(true);

Assert.assertEquals(KubernetesPeonLifecycle.State.NOT_STARTED, peonLifecycle.getState());

Expand Down Expand Up @@ -512,7 +504,6 @@ public void test_join_whenIOExceptionThrownWhileStreamingTaskLogs_isIgnored() th
EasyMock.expectLastCall().once();
logWatch.close();
EasyMock.expectLastCall();
EasyMock.expect(kubernetesClient.deletePeonJob(k8sTaskId)).andReturn(true);

Assert.assertEquals(KubernetesPeonLifecycle.State.NOT_STARTED, peonLifecycle.getState());

Expand Down Expand Up @@ -554,8 +545,6 @@ public void test_join_whenRuntimeExceptionThrownWhileWaitingForKubernetesJob_thr
logWatch.close();
EasyMock.expectLastCall();

EasyMock.expect(kubernetesClient.deletePeonJob(k8sTaskId)).andReturn(true);

Assert.assertEquals(KubernetesPeonLifecycle.State.NOT_STARTED, peonLifecycle.getState());

replayAll();
Expand Down Expand Up @@ -908,8 +897,11 @@ public void test_getTaskLocation_withStoppedTaskState_returnsUnknown()
stateListener
);
setPeonLifecycleState(peonLifecycle, KubernetesPeonLifecycle.State.STOPPED);
EasyMock.expect(kubernetesClient.getPeonPod(k8sTaskId.getK8sJobName())).andReturn(Optional.absent()).once();

replayAll();
Assert.assertEquals(TaskLocation.unknown(), peonLifecycle.getTaskLocation());
verifyAll();
}

private void setPeonLifecycleState(KubernetesPeonLifecycle peonLifecycle, KubernetesPeonLifecycle.State state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ public void test_run_withoutExistingTask() throws IOException, ExecutionExceptio
Assert.assertEquals(taskStatus, future.get());

verifyAll();

Assert.assertFalse(runner.tasks.containsKey(task.getId()));
}

@Test
Expand Down Expand Up @@ -191,8 +189,6 @@ public void test_run_whenExceptionThrown_throwsRuntimeException() throws IOExcep
Assert.assertTrue(e.getCause() instanceof RuntimeException);

verifyAll();

Assert.assertFalse(runner.tasks.containsKey(task.getId()));
}

@Test
Expand All @@ -208,8 +204,6 @@ public void test_join_withoutExistingTask() throws ExecutionException, Interrupt
Assert.assertEquals(taskStatus, future.get());

verifyAll();

Assert.assertFalse(runner.tasks.containsKey(task.getId()));
}

@Test
Expand All @@ -236,28 +230,11 @@ public void test_join_whenExceptionThrown_throwsRuntimeException()
Assert.assertTrue(e.getCause() instanceof RuntimeException);

verifyAll();

Assert.assertFalse(runner.tasks.containsKey(task.getId()));
}

@Test
public void test_doTask_withoutWorkItem_throwsRuntimeException()
{
Assert.assertThrows(
"Task [id] disappeared",
RuntimeException.class,
() -> runner.doTask(task, true)
);
}

@Test
public void test_doTask_whenShutdownRequested_throwsRuntimeException()
{
KubernetesWorkItem workItem = new KubernetesWorkItem(task, null);
workItem.shutdown();

runner.tasks.put(task.getId(), workItem);

Assert.assertThrows(
"Task [id] has been shut down",
RuntimeException.class,
Expand All @@ -266,13 +243,7 @@ public void test_doTask_whenShutdownRequested_throwsRuntimeException()
}

@Test
public void test_shutdown_withoutExistingTask()
{
runner.shutdown(task.getId(), "");
}

@Test
public void test_shutdown_withExistingTask()
public void test_shutdown_withExistingTask_removesTaskFromMap()
{
KubernetesWorkItem workItem = new KubernetesWorkItem(task, null) {
@Override
Expand All @@ -282,7 +253,13 @@ protected synchronized void shutdown()
};

runner.tasks.put(task.getId(), workItem);
runner.shutdown(task.getId(), "");
Assert.assertTrue(runner.tasks.isEmpty());
}

/**
 * Calls {@code runner.shutdown} for a task id that this test never adds to {@code runner.tasks}.
 * There are no assertions; the test passes as long as the call does not throw.
 */
@Test
public void test_shutdown_withoutExistingTask()
{
runner.shutdown(task.getId(), "");
}

Expand Down Expand Up @@ -629,6 +606,30 @@ public TaskLocation getLocation()
verifyAll();
}

/**
 * Verifies that {@code runner.getTaskLocation} returns the location reported by the work item
 * registered under the task's id.
 */
@Test
public void test_getTaskLocation_withExistingTask()
{
// Stub the work item so that getLocation() returns a fixed, known location.
KubernetesWorkItem workItem = new KubernetesWorkItem(task, null) {
@Override
public TaskLocation getLocation()
{
return TaskLocation.create("host", 0, 1, false);
}
};

// Register the stubbed work item directly in the runner's task map.
runner.tasks.put(task.getId(), workItem);

TaskLocation taskLocation = runner.getTaskLocation(task.getId());
Assert.assertEquals(TaskLocation.create("host", 0, 1, false), taskLocation);
}

/**
 * Verifies that {@code runner.getTaskLocation} returns {@link TaskLocation#unknown()} when this
 * test has not registered any work item for the task's id.
 */
@Test
public void test_getTaskLocation_noTaskFound()
{
TaskLocation taskLocation = runner.getTaskLocation(task.getId());
Assert.assertEquals(TaskLocation.unknown(), taskLocation);
}

@Test
public void test_getTotalCapacity()
{
Expand All @@ -644,6 +645,5 @@ public void test_getUsedCapacity()
Assert.assertEquals(1, runner.getUsedCapacity());
runner.tasks.remove(task.getId());
Assert.assertEquals(0, runner.getUsedCapacity());

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ public void test_setKubernetesPeonLifecycleTwice_throwsIllegalStateException()
public void test_shutdown_withoutKubernetesPeonLifecycle()
{
workItem.shutdown();
Assert.assertTrue(workItem.isShutdownRequested());
}

@Test
Expand All @@ -91,7 +90,6 @@ public void test_shutdown_withKubernetesPeonLifecycle()

workItem.shutdown();
verifyAll();
Assert.assertTrue(workItem.isShutdownRequested());
}

@Test
Expand Down

0 comments on commit 95b0de6

Please sign in to comment.