Skip to content

Commit

Permalink
Remove callbacks from job_queue_node.
Browse files Browse the repository at this point in the history
Removed callbacks and split callback_arguments into run_arg and
ensemble_config.
Updated tests.
Inlined exit callback into job_queue_node.
  • Loading branch information
JHolba committed Sep 20, 2023
1 parent 1a4d3f4 commit bc2f62f
Show file tree
Hide file tree
Showing 15 changed files with 142 additions and 128 deletions.
10 changes: 1 addition & 9 deletions src/ert/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,14 @@
import logging
import time
from pathlib import Path
from typing import Callable, Iterable, Tuple
from typing import Iterable

from ert.config import EnsembleConfig, ParameterConfig, SummaryConfig
from ert.run_arg import RunArg

from .load_status import LoadResult, LoadStatus
from .realization_state import RealizationState

CallbackArgs = Tuple[RunArg, EnsembleConfig]
Callback = Callable[[RunArg, EnsembleConfig], LoadResult]

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -102,8 +99,3 @@ def forward_model_ok(
)

return final_result


def forward_model_exit(run_arg: RunArg, _: EnsembleConfig) -> LoadResult:
run_arg.ensemble_storage.state_map[run_arg.iens] = RealizationState.LOAD_FAILURE
return LoadResult(None, "")
14 changes: 5 additions & 9 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,12 @@
from ert.async_utils import get_event_loop
from ert.ensemble_evaluator import identifiers
from ert.job_queue import Driver, JobQueue
from ert.load_status import LoadResult

from .._wait_for_evaluator import wait_for_evaluator
from ._ensemble import Ensemble

if TYPE_CHECKING:
from ert.callbacks import Callback
from ert.config import AnalysisConfig, EnsembleConfig, QueueConfig
from ert.run_arg import RunArg
from ert.config import AnalysisConfig, QueueConfig

from ..config import EvaluatorServerConfig
from ._realization import Realization
Expand Down Expand Up @@ -137,12 +134,11 @@ def setup_timeout_callback(
timeout_queue: asyncio.Queue[MsgType],
cloudevent_unary_send: Callable[[MsgType], Awaitable[None]],
event_generator: Callable[[str, Optional[int]], MsgType],
) -> Tuple[Callback, asyncio.Task[None]]:
def on_timeout(run_args: RunArg, _: EnsembleConfig) -> LoadResult:
) -> Tuple[Callable[[int], None], asyncio.Task[None]]:
def on_timeout(iens: int) -> None:
timeout_queue.put_nowait(
event_generator(identifiers.EVTYPE_FM_STEP_TIMEOUT, run_args.iens)
event_generator(identifiers.EVTYPE_FM_STEP_TIMEOUT, iens)
)
return LoadResult(None, "timed out")

async def send_timeout_message() -> None:
while True:
Expand Down Expand Up @@ -215,7 +211,7 @@ async def evaluate_async(
) -> None:
self._config = config
await self._evaluate_inner(
cloudevent_unary_send=self.queue_cloudevent, # type: ignore
cloudevent_unary_send=self.queue_cloudevent,
output_bus=self.output_bus,
experiment_id=experiment_id,
)
Expand Down
8 changes: 3 additions & 5 deletions src/ert/ensemble_evaluator/_builder/_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

SOURCE_TEMPLATE_STEP = "/step/{step_id}"
if TYPE_CHECKING:
from ert.callbacks import Callback, CallbackArgs
from ert.config.ensemble_config import EnsembleConfig
from ert.run_arg import RunArg


Expand All @@ -26,11 +26,9 @@ class LegacyStep:
jobs: Sequence[LegacyJob]
name: str
max_runtime: Optional[int]
callback_arguments: CallbackArgs
done_callback: Callback
exit_callback: Callback
run_arg: "RunArg"
ensemble_config: "EnsembleConfig"
num_cpu: int
run_path: Path
job_script: str
job_name: str
run_arg: RunArg
38 changes: 20 additions & 18 deletions src/ert/job_queue/job_queue_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,24 @@
import random
import time
from threading import Lock, Semaphore, Thread
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Optional

from cwrap import BaseCClass
from ecl.util.util import StringList

from ert._clib.queue import _refresh_status # pylint: disable=import-error
from ert.callbacks import forward_model_ok
from ert.load_status import LoadStatus

from ..realization_state import RealizationState
from . import ResPrototype
from .job_status import JobStatus
from .submit_status import SubmitStatus
from .thread_status import ThreadStatus

if TYPE_CHECKING:
from ert.callbacks import Callback, CallbackArgs

from ..config import EnsembleConfig
from ..run_arg import RunArg
from .driver import Driver

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -94,16 +96,14 @@ def __init__(
num_cpu: int,
status_file: str,
exit_file: str,
done_callback_function: Callback,
exit_callback_function: Callback,
callback_arguments: CallbackArgs,
run_arg: "RunArg",
ensemble_config: "EnsembleConfig",
max_runtime: Optional[int] = None,
callback_timeout: Optional[Callback] = None,
callback_timeout: Optional[Callable[[int], None]] = None,
):
self.done_callback_function = done_callback_function
self.exit_callback_function = exit_callback_function
self.callback_timeout = callback_timeout
self.callback_arguments = callback_arguments
self.run_arg = run_arg
self.ensemble_config = ensemble_config
argc = 1
argv = StringList()
argv.append(run_path)
Expand Down Expand Up @@ -177,8 +177,8 @@ def submit(self, driver: "Driver") -> SubmitStatus:
return self._submit(driver)

def run_done_callback(self) -> Optional[LoadStatus]:
callback_status, status_msg = self.done_callback_function(
*self.callback_arguments
callback_status, status_msg = forward_model_ok(
self.run_arg, self.ensemble_config
)
if callback_status == LoadStatus.LOAD_SUCCESSFUL:
self._set_queue_status(JobStatus.SUCCESS)
Expand All @@ -194,10 +194,12 @@ def run_done_callback(self) -> Optional[LoadStatus]:

def run_timeout_callback(self) -> None:
if self.callback_timeout:
self.callback_timeout(*self.callback_arguments)
self.callback_timeout(self.run_arg.iens)

def run_exit_callback(self) -> None:
self.exit_callback_function(*self.callback_arguments)
self.run_arg.ensemble_storage.state_map[
self.run_arg.iens
] = RealizationState.LOAD_FAILURE

def is_running(self, given_status: Optional[JobStatus] = None) -> bool:
status = given_status or self.status
Expand Down Expand Up @@ -297,7 +299,7 @@ def _handle_end_status(
if end_status == JobStatus.DONE:
with pool_sema:
logger.info(
f"Realization: {self.callback_arguments[0].iens} complete, "
f"Realization: {self.run_arg.iens} complete, "
"starting to load results"
)
self.run_done_callback()
Expand All @@ -310,19 +312,19 @@ def _handle_end_status(
elif current_status in self.RESUBMIT_STATES:
if self.submit_attempt < max_submit:
logger.warning(
f"Realization: {self.callback_arguments[0].iens} "
f"Realization: {self.run_arg.iens} "
f"failed with: {self._status_msg}, resubmitting"
)
self._transition_status(ThreadStatus.READY, current_status)
else:
self._transition_to_failure(
message=f"Realization: {self.callback_arguments[0].iens} "
message=f"Realization: {self.run_arg.iens} "
"failed after reaching max submit"
f" ({max_submit}):\n\t{self._status_msg}"
)
elif current_status in self.FAILURE_STATES:
self._transition_to_failure(
message=f"Realization: {self.callback_arguments[0].iens} "
message=f"Realization: {self.run_arg.iens} "
f"failed with: {self._status_msg}"
)
else:
Expand Down
15 changes: 5 additions & 10 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from . import ResPrototype

if TYPE_CHECKING:
from ert.callbacks import Callback
from ert.config import ErtConfig
from ert.ensemble_evaluator import LegacyStep
from ert.run_arg import RunArg
Expand Down Expand Up @@ -484,8 +483,6 @@ def add_job_from_run_arg(
run_arg: "RunArg",
ert_config: "ErtConfig",
max_runtime: Optional[int],
ok_cb: Callback,
exit_cb: Callback,
num_cpu: int,
) -> None:
job_name = run_arg.job_name
Expand All @@ -499,9 +496,8 @@ def add_job_from_run_arg(
num_cpu=num_cpu,
status_file=self.status_file,
exit_file=self.exit_file,
done_callback_function=ok_cb,
exit_callback_function=exit_cb,
callback_arguments=(run_arg, ert_config.ensemble_config),
run_arg=run_arg,
ensemble_config=ert_config.ensemble_config,
max_runtime=max_runtime,
)

Expand All @@ -512,7 +508,7 @@ def add_job_from_run_arg(
def add_ee_stage(
self,
stage: "LegacyStep",
callback_timeout: Optional[Callback] = None,
callback_timeout: Optional[Callable[[int], None]] = None,
) -> None:
job = JobQueueNode(
job_script=stage.job_script,
Expand All @@ -521,9 +517,8 @@ def add_ee_stage(
num_cpu=stage.num_cpu,
status_file=self.status_file,
exit_file=self.exit_file,
done_callback_function=stage.done_callback,
exit_callback_function=stage.exit_callback,
callback_arguments=stage.callback_arguments,
run_arg=stage.run_arg,
ensemble_config=stage.ensemble_config,
max_runtime=stage.max_runtime,
callback_timeout=callback_timeout,
)
Expand Down
12 changes: 3 additions & 9 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from cloudevents.http import CloudEvent

import _ert_com_protocol
from ert.callbacks import forward_model_exit, forward_model_ok
from ert.cli import MODULE_MODE
from ert.config import HookRuntime
from ert.enkf_main import EnKFMain
Expand Down Expand Up @@ -389,23 +388,18 @@ def _build_ensemble(
self.ert().resConfig().forward_model_list
)
]
real.active(True).add_step(
real.add_step(
LegacyStep(
id_="0",
jobs=jobs,
name="legacy step",
max_runtime=self.ert().analysisConfig().max_runtime,
callback_arguments=(
run_arg,
self.ert().resConfig().ensemble_config,
),
done_callback=forward_model_ok,
exit_callback=forward_model_exit,
run_arg=run_arg,
ensemble_config=self.ert().resConfig().ensemble_config,
num_cpu=self.ert().get_num_cpu(),
run_path=Path(run_arg.runpath),
job_script=self.ert().resConfig().queue_config.job_script,
job_name=run_arg.job_name,
run_arg=run_arg,
)
)
builder.add_realization(real)
Expand Down
3 changes: 0 additions & 3 deletions src/ert/simulator/simulation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from time import sleep
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

from ert.callbacks import forward_model_exit, forward_model_ok
from ert.config import HookRuntime
from ert.job_queue import Driver, JobQueue, JobQueueManager, RunStatus
from ert.realization_state import RealizationState
Expand Down Expand Up @@ -49,8 +48,6 @@ def _run_forward_model(
run_arg,
ert.resConfig(),
max_runtime,
forward_model_ok,
forward_model_exit,
ert.get_num_cpu(),
)

Expand Down
16 changes: 12 additions & 4 deletions tests/unit_tests/ensemble_evaluator/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ert.ensemble_evaluator.config import EvaluatorServerConfig
from ert.ensemble_evaluator.evaluator import EnsembleEvaluator
from ert.ensemble_evaluator.snapshot import SnapshotBuilder
from ert.job_queue import JobQueueNode
from ert.load_status import LoadStatus

from .ensemble_evaluator_utils import TestEnsemble
Expand Down Expand Up @@ -66,7 +67,16 @@ def queue_config_fixture():

@pytest.fixture
def make_ensemble_builder(queue_config):
def _make_ensemble_builder(tmpdir, num_reals, num_jobs, job_sleep=0):
def _make_ensemble_builder(monkeypatch, tmpdir, num_reals, num_jobs, job_sleep=0):
monkeypatch.setattr(
ert.job_queue.job_queue_node,
"forward_model_ok",
lambda _, _b: (LoadStatus.LOAD_SUCCESSFUL, ""),
)
monkeypatch.setattr(
JobQueueNode, "run_exit_callback", lambda _: (LoadStatus.LOAD_FAILURE, "")
)

builder = ert.ensemble_evaluator.EnsembleBuilder()
with tmpdir.as_cwd():
ext_job_list = []
Expand Down Expand Up @@ -121,11 +131,9 @@ class RunArg:
job_script="job_dispatch.py",
max_runtime=10,
run_arg=Mock(iens=iens),
done_callback=lambda _, _b: (LoadStatus.LOAD_SUCCESSFUL, ""),
exit_callback=lambda _, _b: (LoadStatus.LOAD_FAILURE, ""),
# the first callback_argument is expected to be a run_arg
# from the run_arg, the queue wants to access the iens prop
callback_arguments=(RunArg(iens), None),
ensemble_config=None,
run_path=run_path,
num_cpu=1,
name="dummy step",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def __init__(self, _iter, reals, steps, jobs, id_):
],
name=f"step-{step_no}",
max_runtime=0,
callback_arguments=(),
done_callback=None,
exit_callback=None,
ensemble_config=None,
num_cpu=0,
run_path=None,
run_arg=None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ async def _handler(websocket, path):
@pytest.mark.asyncio
@pytest.mark.timeout(60)
async def test_happy_path(
tmpdir, unused_tcp_port, event_loop, make_ensemble_builder, queue_config, caplog
tmpdir,
unused_tcp_port,
event_loop,
make_ensemble_builder,
queue_config,
caplog,
monkeypatch,
):
asyncio.set_event_loop(event_loop)
host = "localhost"
Expand All @@ -44,7 +50,7 @@ async def test_happy_path(
mock_ws_task = get_event_loop().create_task(mock_ws(host, unused_tcp_port, done))
await wait_for_evaluator(base_url=url, timeout=5)

ensemble = make_ensemble_builder(tmpdir, 1, 1).build()
ensemble = make_ensemble_builder(monkeypatch, tmpdir, 1, 1).build()
queue = JobQueue(
Driver.create_driver(queue_config), max_submit=queue_config.max_submit
)
Expand Down
4 changes: 1 addition & 3 deletions tests/unit_tests/ensemble_evaluator/test_ensemble_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,7 @@ def test_build_ensemble(active_real):
job_script="job_script",
job_name="job_name",
num_cpu=1,
callback_arguments=MagicMock(),
done_callback=MagicMock(),
exit_callback=MagicMock(),
ensemble_config=MagicMock(),
jobs=[
LegacyJob(
ext_job=MagicMock(),
Expand Down
Loading

0 comments on commit bc2f62f

Please sign in to comment.