Skip to content

Commit

Permalink
Add stop_long_running_jobs funcitonality to Scheduler
Browse files Browse the repository at this point in the history
This adds two tasks to scheduler. 1) Processing the finished jobs and computing the running average 2) Checking that the duration of still running jobs is bellow the threshold and kills those jobs otherwise.
  • Loading branch information
xjules committed Jan 3, 2024
1 parent 6d2f9ad commit ba865c9
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 26 deletions.
14 changes: 5 additions & 9 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,16 +223,12 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
event_creator(identifiers.EVTYPE_ENSEMBLE_STARTED, None)
)

if isinstance(queue, Scheduler):
queue.add_dispatch_information_to_jobs_file()
result = await queue.execute()
elif isinstance(queue, JobQueue):
min_required_realizations = (
self.min_required_realizations if self.stop_long_running else 0
)
queue.add_dispatch_information_to_jobs_file()
min_required_realizations = (
self.min_required_realizations if self.stop_long_running else 0
)

result = await queue.execute(min_required_realizations)
queue.add_dispatch_information_to_jobs_file()
result = await queue.execute(min_required_realizations)

except Exception:
logger.exception(
Expand Down
16 changes: 16 additions & 0 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import time
import uuid
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -67,6 +68,8 @@ def __init__(self, scheduler: Scheduler, real: Realization) -> None:
self._scheduler: Scheduler = scheduler
self._callback_status_msg: str = ""
self._requested_max_submit: Optional[int] = None
self._start_time: Optional[float] = None
self._end_time: Optional[float] = None

@property
def iens(self) -> int:
Expand All @@ -76,6 +79,14 @@ def iens(self) -> int:
def driver(self) -> Driver:
return self._scheduler.driver

@property
def running_duration(self) -> float:
if self._start_time:
if self._end_time:
return self._end_time - self._start_time
return time.time() - self._start_time
return 0

async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
await sem.acquire()
timeout_task: Optional[asyncio.Task[None]] = None
Expand All @@ -88,6 +99,7 @@ async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:

await self._send(State.PENDING)
await self.started.wait()
self._start_time = time.time()

await self._send(State.RUNNING)
if self.real.max_runtime is not None and self.real.max_runtime > 0:
Expand Down Expand Up @@ -179,6 +191,10 @@ async def _send(self, state: State) -> None:
if state in (State.FAILED, State.ABORTED):
await self._handle_failure()

if state == State.COMPLETED:
self._end_time = time.time()
await self._scheduler.completed_jobs.put(self.iens)

status = STATE_TO_LEGACY[state]
event = CloudEvent(
{
Expand Down
48 changes: 37 additions & 11 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,7 @@
from collections import defaultdict
from dataclasses import asdict
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Dict,
MutableMapping,
Optional,
Sequence,
)
from typing import TYPE_CHECKING, Any, Dict, MutableMapping, Optional, Sequence

from pydantic.dataclasses import dataclass
from websockets import Headers
Expand Down Expand Up @@ -69,6 +62,10 @@ def __init__(
}

self._events: asyncio.Queue[Any] = asyncio.Queue()
self._average_job_runtime: float = 0
self._completed_jobs_num: int = 0
self.completed_jobs: asyncio.Queue[int] = asyncio.Queue()

self._cancelled = False
self._max_submit = max_submit
self._max_running = max_running
Expand All @@ -83,8 +80,29 @@ def kill_all_jobs(self) -> None:
for task in self._tasks.values():
task.cancel()

def stop_long_running_jobs(self, minimum_required_realizations: int) -> None:
pass
async def _update_avg_job_runtime(self) -> None:
while True:
iens = await self.completed_jobs.get()
self._average_job_runtime = (
self._average_job_runtime * self._completed_jobs_num
+ self._jobs[iens].running_duration
) / (self._completed_jobs_num + 1)
self._completed_jobs_num += 1

async def _stop_long_running_jobs(
self, minimum_required_realizations: int, long_running_factor: float = 1.25
) -> None:
while True:
if self._completed_jobs_num >= minimum_required_realizations:
for iens, task in self._tasks.items():
if (
self._jobs[iens].running_duration
> long_running_factor * self._average_job_runtime
and not task.done()
):
task.cancel()
await task
await asyncio.sleep(0.1)

def set_realization(self, realization: Realization) -> None:
self._jobs[realization.iens] = Job(self, realization)
Expand Down Expand Up @@ -126,11 +144,19 @@ def add_dispatch_information_to_jobs_file(self) -> None:
for job in self._jobs.values():
self._update_jobs_json(job.iens, job.real.run_arg.runpath)

async def execute(self, minimum_required_realizations: int = 0) -> str:
async def execute(
self,
min_required_realizations: int = 0,
) -> str:
async with background_tasks() as cancel_when_execute_is_done:
cancel_when_execute_is_done(self._publisher())
cancel_when_execute_is_done(self._process_event_queue())
cancel_when_execute_is_done(self.driver.poll())
if min_required_realizations > 0:
cancel_when_execute_is_done(
self._stop_long_running_jobs(min_required_realizations)
)
cancel_when_execute_is_done(self._update_avg_job_runtime())

start = asyncio.Event()
sem = asyncio.BoundedSemaphore(self._max_running)
Expand Down
19 changes: 13 additions & 6 deletions tests/unit_tests/scheduler/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,26 @@ def __init__(self, init=None, wait=None, kill=None):
self._mock_wait = wait
self._mock_kill = kill

async def _init(self, *args, **kwargs):
async def _init(self, iens, *args, **kwargs):
if self._mock_init is not None:
await self._mock_init(*args, **kwargs)
await self._mock_init(iens, *args, **kwargs)
return iens

async def _wait(self, *args):
async def _wait(self, iens):
if self._mock_wait is not None:
result = await self._mock_wait()
if self._mock_wait.__code__.co_argcount > 0:
result = await self._mock_wait(iens)
else:
result = await self._mock_wait()
return True if result is None else bool(result)
return True

async def _kill(self, *args):
async def _kill(self, iens, *args):
if self._mock_kill is not None:
await self._mock_kill()
if self._mock_kill.__code__.co_argcount > 0:
await self._mock_kill(iens)
else:
await self._mock_kill()


@pytest.fixture
Expand Down
31 changes: 31 additions & 0 deletions tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,34 @@ async def init(iens, *args, **kwargs):
assert sch.is_active()
await execute_task
assert not sch.is_active()


@pytest.mark.timeout(6)
async def test_that_long_running_jobs_were_stopped(storage, tmp_path, mock_driver):
killed_iens = []

async def kill(iens):
nonlocal killed_iens
killed_iens.append(iens)

async def wait(iens):
# all jobs with iens > 5 will sleep for 10 seconds and should be killed
if iens < 6:
await asyncio.sleep(0.1)
else:
await asyncio.sleep(10)
return True

ensemble_size = 10
ensemble = storage.create_experiment().create_ensemble(
name="foo", ensemble_size=ensemble_size
)
realizations = [
create_stub_realization(ensemble, tmp_path, iens)
for iens in range(ensemble_size)
]

sch = scheduler.Scheduler(mock_driver(wait=wait, kill=kill), realizations)

assert await sch.execute(min_required_realizations=5) == EVTYPE_ENSEMBLE_STOPPED
assert killed_iens == [6, 7, 8, 9]

0 comments on commit ba865c9

Please sign in to comment.