From ef63ac03599bf90ceb4113f2372f62d9875f2120 Mon Sep 17 00:00:00 2001 From: Jon Holba Date: Tue, 19 Sep 2023 20:52:36 +0200 Subject: [PATCH] Remove callbacks from job_queue_node. Removed callbacks and split callback_arguments into run_arg and ensemble_config. Updated tests. Inlined exit callback into job_queue_node. --- src/ert/callbacks.py | 10 +--- .../ensemble_evaluator/_builder/_legacy.py | 12 ++--- src/ert/ensemble_evaluator/_builder/_step.py | 8 ++-- src/ert/job_queue/job_queue_node.py | 38 ++++++++------- src/ert/job_queue/queue.py | 15 ++---- src/ert/run_models/base_run_model.py | 12 ++--- src/ert/simulator/simulation_context.py | 3 -- .../unit_tests/ensemble_evaluator/conftest.py | 16 +++++-- .../ensemble_evaluator_utils.py | 4 +- .../test_async_queue_execution.py | 10 +++- .../test_ensemble_builder.py | 4 +- .../test_ensemble_legacy.py | 14 +++--- tests/unit_tests/job_queue/test_job_queue.py | 36 ++++++++------ .../job_queue/test_job_queue_manager.py | 47 +++++++++++-------- .../test_job_queue_manager_torque.py | 39 ++++++++++----- 15 files changed, 141 insertions(+), 127 deletions(-) diff --git a/src/ert/callbacks.py b/src/ert/callbacks.py index bca3cf73572..b3904dc9f5c 100644 --- a/src/ert/callbacks.py +++ b/src/ert/callbacks.py @@ -3,7 +3,7 @@ import logging import time from pathlib import Path -from typing import Callable, Iterable, Tuple +from typing import Iterable from ert.config import EnsembleConfig, ParameterConfig, SummaryConfig from ert.run_arg import RunArg @@ -11,9 +11,6 @@ from .load_status import LoadResult, LoadStatus from .realization_state import RealizationState -CallbackArgs = Tuple[RunArg, EnsembleConfig] -Callback = Callable[[RunArg, EnsembleConfig], LoadResult] - logger = logging.getLogger(__name__) @@ -102,8 +99,3 @@ def forward_model_ok( ) return final_result - - -def forward_model_exit(run_arg: RunArg, _: EnsembleConfig) -> LoadResult: - run_arg.ensemble_storage.state_map[run_arg.iens] = RealizationState.LOAD_FAILURE - return LoadResult(None, "") diff --git a/src/ert/ensemble_evaluator/_builder/_legacy.py b/src/ert/ensemble_evaluator/_builder/_legacy.py index 056313ea0b1..90fc556fa9d 100644 --- a/src/ert/ensemble_evaluator/_builder/_legacy.py +++ b/src/ert/ensemble_evaluator/_builder/_legacy.py @@ -25,15 +25,12 @@ from ert.async_utils import get_event_loop from ert.ensemble_evaluator import identifiers from ert.job_queue import Driver, JobQueue -from ert.load_status import LoadResult from .._wait_for_evaluator import wait_for_evaluator from ._ensemble import Ensemble if TYPE_CHECKING: - from ert.callbacks import Callback - from ert.config import AnalysisConfig, EnsembleConfig, QueueConfig - from ert.run_arg import RunArg + from ert.config import AnalysisConfig, QueueConfig from ..config import EvaluatorServerConfig from ._realization import Realization @@ -137,12 +134,11 @@ def setup_timeout_callback( timeout_queue: asyncio.Queue[MsgType], cloudevent_unary_send: Callable[[MsgType], Awaitable[None]], event_generator: Callable[[str, Optional[int]], MsgType], - ) -> Tuple[Callback, asyncio.Task[None]]: - def on_timeout(run_args: RunArg, _: EnsembleConfig) -> LoadResult: + ) -> Tuple[Callable[[int], None], asyncio.Task[None]]: + def on_timeout(iens: int) -> None: timeout_queue.put_nowait( - event_generator(identifiers.EVTYPE_FM_STEP_TIMEOUT, run_args.iens) + event_generator(identifiers.EVTYPE_FM_STEP_TIMEOUT, iens) ) - return LoadResult(None, "timed out") async def send_timeout_message() -> None: while True: diff --git a/src/ert/ensemble_evaluator/_builder/_step.py b/src/ert/ensemble_evaluator/_builder/_step.py index e97123d0c3f..c6eaedbcdb1 100644 --- a/src/ert/ensemble_evaluator/_builder/_step.py +++ b/src/ert/ensemble_evaluator/_builder/_step.py @@ -8,7 +8,7 @@ SOURCE_TEMPLATE_STEP = "/step/{step_id}" if TYPE_CHECKING: - from ert.callbacks import Callback, CallbackArgs + from ert.config.ensemble_config import EnsembleConfig from ert.run_arg import RunArg @@ -26,11 +26,9 @@ class LegacyStep: jobs: Sequence[LegacyJob] name: str max_runtime: Optional[int] - callback_arguments: CallbackArgs - done_callback: Callback - exit_callback: Callback + run_arg: "RunArg" + ensemble_config: "EnsembleConfig" num_cpu: int run_path: Path job_script: str job_name: str - run_arg: RunArg diff --git a/src/ert/job_queue/job_queue_node.py b/src/ert/job_queue/job_queue_node.py index 10e9c516300..67628ed177b 100644 --- a/src/ert/job_queue/job_queue_node.py +++ b/src/ert/job_queue/job_queue_node.py @@ -4,22 +4,24 @@ import random import time from threading import Lock, Semaphore, Thread -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Callable, Optional from cwrap import BaseCClass from ecl.util.util import StringList from ert._clib.queue import _refresh_status # pylint: disable=import-error +from ert.callbacks import forward_model_ok from ert.load_status import LoadStatus +from ..realization_state import RealizationState from . import ResPrototype from .job_status import JobStatus from .submit_status import SubmitStatus from .thread_status import ThreadStatus if TYPE_CHECKING: - from ert.callbacks import Callback, CallbackArgs - + from ..config import EnsembleConfig + from ..run_arg import RunArg from .driver import Driver logger = logging.getLogger(__name__) @@ -94,16 +96,14 @@ def __init__( num_cpu: int, status_file: str, exit_file: str, - done_callback_function: Callback, - exit_callback_function: Callback, - callback_arguments: CallbackArgs, + run_arg: "RunArg", + ensemble_config: "EnsembleConfig", max_runtime: Optional[int] = None, - callback_timeout: Optional[Callback] = None, + callback_timeout: Optional[Callable[[int], None]] = None, ): - self.done_callback_function = done_callback_function - self.exit_callback_function = exit_callback_function self.callback_timeout = callback_timeout - self.callback_arguments = callback_arguments + self.run_arg = run_arg + self.ensemble_config = ensemble_config argc = 1 argv = StringList() argv.append(run_path) @@ -177,8 +177,8 @@ def submit(self, driver: "Driver") -> SubmitStatus: return self._submit(driver) def run_done_callback(self) -> Optional[LoadStatus]: - callback_status, status_msg = self.done_callback_function( - *self.callback_arguments + callback_status, status_msg = forward_model_ok( + self.run_arg, self.ensemble_config ) if callback_status == LoadStatus.LOAD_SUCCESSFUL: self._set_queue_status(JobStatus.SUCCESS) @@ -194,10 +194,12 @@ def run_done_callback(self) -> Optional[LoadStatus]: def run_timeout_callback(self) -> None: if self.callback_timeout: - self.callback_timeout(*self.callback_arguments) + self.callback_timeout(self.run_arg.iens) def run_exit_callback(self) -> None: - self.exit_callback_function(*self.callback_arguments) + self.run_arg.ensemble_storage.state_map[ + self.run_arg.iens + ] = RealizationState.LOAD_FAILURE def is_running(self, given_status: Optional[JobStatus] = None) -> bool: status = given_status or self.status @@ -297,7 +299,7 @@ def _handle_end_status( if end_status == JobStatus.DONE: with pool_sema: logger.info( - f"Realization: {self.callback_arguments[0].iens} complete, " + f"Realization: {self.run_arg.iens} complete, " "starting to load results" ) self.run_done_callback() @@ -310,19 +312,19 @@ def _handle_end_status( elif current_status in self.RESUBMIT_STATES: if self.submit_attempt < max_submit: logger.warning( - f"Realization: {self.callback_arguments[0].iens} " + f"Realization: {self.run_arg.iens} " f"failed with: {self._status_msg}, resubmitting" ) self._transition_status(ThreadStatus.READY, current_status) else: self._transition_to_failure( - message=f"Realization: {self.callback_arguments[0].iens} " + message=f"Realization: {self.run_arg.iens} " "failed after reaching max submit" f" ({max_submit}):\n\t{self._status_msg}" ) elif current_status in self.FAILURE_STATES: self._transition_to_failure( - message=f"Realization: {self.callback_arguments[0].iens} " + message=f"Realization: {self.run_arg.iens} " f"failed with: {self._status_msg}" ) else: diff --git a/src/ert/job_queue/queue.py b/src/ert/job_queue/queue.py index 763f871a6fa..ecb9ffce3f8 100644 --- a/src/ert/job_queue/queue.py +++ b/src/ert/job_queue/queue.py @@ -41,7 +41,6 @@ from . import ResPrototype if TYPE_CHECKING: - from ert.callbacks import Callback from ert.config import ErtConfig from ert.ensemble_evaluator import LegacyStep from ert.run_arg import RunArg @@ -484,8 +483,6 @@ def add_job_from_run_arg( run_arg: "RunArg", ert_config: "ErtConfig", max_runtime: Optional[int], - ok_cb: Callback, - exit_cb: Callback, num_cpu: int, ) -> None: job_name = run_arg.job_name @@ -499,9 +496,8 @@ def add_job_from_run_arg( num_cpu=num_cpu, status_file=self.status_file, exit_file=self.exit_file, - done_callback_function=ok_cb, - exit_callback_function=exit_cb, - callback_arguments=(run_arg, ert_config.ensemble_config), + run_arg=run_arg, + ensemble_config=ert_config.ensemble_config, max_runtime=max_runtime, ) @@ -512,7 +508,7 @@ def add_job_from_run_arg( def add_ee_stage( self, stage: "LegacyStep", - callback_timeout: Optional[Callback] = None, + callback_timeout: Optional[Callable[[int], None]] = None, ) -> None: job = JobQueueNode( job_script=stage.job_script, @@ -521,9 +517,8 @@ def add_ee_stage( num_cpu=stage.num_cpu, status_file=self.status_file, exit_file=self.exit_file, - done_callback_function=stage.done_callback, - exit_callback_function=stage.exit_callback, - callback_arguments=stage.callback_arguments, + run_arg=stage.run_arg, + ensemble_config=stage.ensemble_config, max_runtime=stage.max_runtime, callback_timeout=callback_timeout, ) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index ee9a851862f..cc5a006b120 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -15,7 +15,6 @@ from cloudevents.http import CloudEvent import _ert_com_protocol -from ert.callbacks import forward_model_exit, forward_model_ok from ert.cli import MODULE_MODE from ert.config import HookRuntime from ert.enkf_main import EnKFMain @@ -389,23 +388,18 @@ def _build_ensemble( self.ert().resConfig().forward_model_list ) ] - real.active(True).add_step( + real.add_step( LegacyStep( id_="0", jobs=jobs, name="legacy step", max_runtime=self.ert().analysisConfig().max_runtime, - callback_arguments=( - run_arg, - self.ert().resConfig().ensemble_config, - ), - done_callback=forward_model_ok, - exit_callback=forward_model_exit, + run_arg=run_arg, + ensemble_config=self.ert().resConfig().ensemble_config, num_cpu=self.ert().get_num_cpu(), run_path=Path(run_arg.runpath), job_script=self.ert().resConfig().queue_config.job_script, job_name=run_arg.job_name, - run_arg=run_arg, ) ) builder.add_realization(real) diff --git a/src/ert/simulator/simulation_context.py b/src/ert/simulator/simulation_context.py index 327c95cafd1..1248289be6a 100644 --- a/src/ert/simulator/simulation_context.py +++ b/src/ert/simulator/simulation_context.py @@ -5,7 +5,6 @@ from time import sleep from typing import TYPE_CHECKING, Any, List, Optional, Tuple -from ert.callbacks import forward_model_exit, forward_model_ok from ert.config import HookRuntime from ert.job_queue import Driver, JobQueue, JobQueueManager, RunStatus from ert.realization_state import RealizationState @@ -49,8 +48,6 @@ def _run_forward_model( run_arg, ert.resConfig(), max_runtime, - forward_model_ok, - forward_model_exit, ert.get_num_cpu(), ) diff --git a/tests/unit_tests/ensemble_evaluator/conftest.py b/tests/unit_tests/ensemble_evaluator/conftest.py index 5d5b84a7417..7ff87201b32 100644 --- a/tests/unit_tests/ensemble_evaluator/conftest.py +++ b/tests/unit_tests/ensemble_evaluator/conftest.py @@ -12,6 +12,7 @@ from ert.ensemble_evaluator.config import EvaluatorServerConfig from ert.ensemble_evaluator.evaluator import EnsembleEvaluator from ert.ensemble_evaluator.snapshot import SnapshotBuilder +from ert.job_queue import JobQueueNode from ert.load_status import LoadStatus from .ensemble_evaluator_utils import TestEnsemble @@ -66,7 +67,16 @@ def queue_config_fixture(): @pytest.fixture def make_ensemble_builder(queue_config): - def _make_ensemble_builder(tmpdir, num_reals, num_jobs, job_sleep=0): + def _make_ensemble_builder(monkeypatch, tmpdir, num_reals, num_jobs, job_sleep=0): + monkeypatch.setattr( + ert.job_queue.job_queue_node, + "forward_model_ok", + lambda _, _b: (LoadStatus.LOAD_SUCCESSFUL, ""), + ) + monkeypatch.setattr( + JobQueueNode, "run_exit_callback", lambda _: (LoadStatus.LOAD_FAILURE, "") + ) + builder = ert.ensemble_evaluator.EnsembleBuilder() with tmpdir.as_cwd(): ext_job_list = [] @@ -121,11 +131,9 @@ class RunArg: job_script="job_dispatch.py", max_runtime=10, run_arg=Mock(iens=iens), - done_callback=lambda _, _b: (LoadStatus.LOAD_SUCCESSFUL, ""), - exit_callback=lambda _, _b: (LoadStatus.LOAD_FAILURE, ""), # the first callback_argument is expected to be a run_arg # from the run_arg, the queue wants to access the iens prop - callback_arguments=(RunArg(iens), None), + ensemble_config=None, run_path=run_path, num_cpu=1, name="dummy step", diff --git a/tests/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py b/tests/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py index f75229cc41b..fb7f7cfd52d 100644 --- a/tests/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py +++ b/tests/unit_tests/ensemble_evaluator/ensemble_evaluator_utils.py @@ -71,9 +71,7 @@ def __init__(self, _iter, reals, steps, jobs, id_): ], name=f"step-{step_no}", max_runtime=0, - callback_arguments=(), - done_callback=None, - exit_callback=None, + ensemble_config=None, num_cpu=0, run_path=None, run_arg=None, diff --git a/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py b/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py index 65b622e4fe8..60a41b4d490 100644 --- a/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py +++ b/tests/unit_tests/ensemble_evaluator/test_async_queue_execution.py @@ -34,7 +34,13 @@ async def _handler(websocket, path): @pytest.mark.asyncio @pytest.mark.timeout(60) async def test_happy_path( - tmpdir, unused_tcp_port, event_loop, make_ensemble_builder, queue_config, caplog + tmpdir, + unused_tcp_port, + event_loop, + make_ensemble_builder, + queue_config, + caplog, + monkeypatch, ): asyncio.set_event_loop(event_loop) host = "localhost" @@ -44,7 +50,7 @@ async def test_happy_path( mock_ws_task = get_event_loop().create_task(mock_ws(host, unused_tcp_port, done)) await wait_for_evaluator(base_url=url, timeout=5) - ensemble = make_ensemble_builder(tmpdir, 1, 1).build() + ensemble = make_ensemble_builder(monkeypatch, tmpdir, 1, 1).build() queue = JobQueue( Driver.create_driver(queue_config), max_submit=queue_config.max_submit ) diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_builder.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_builder.py index e13bbcc8868..2f9ef599a46 100644 --- a/tests/unit_tests/ensemble_evaluator/test_ensemble_builder.py +++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_builder.py @@ -29,9 +29,7 @@ def test_build_ensemble(active_real): job_script="job_script", job_name="job_name", num_cpu=1, - callback_arguments=MagicMock(), - done_callback=MagicMock(), - exit_callback=MagicMock(), + ensemble_config=MagicMock(), jobs=[ LegacyJob( ext_job=MagicMock(), diff --git a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py index 3c6c0295529..fda65f1c138 100644 --- a/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py +++ b/tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py @@ -12,11 +12,11 @@ @pytest.mark.timeout(60) -def test_run_legacy_ensemble(tmpdir, make_ensemble_builder): +def test_run_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypatch): num_reals = 2 custom_port_range = range(1024, 65535) with tmpdir.as_cwd(): - ensemble = make_ensemble_builder(tmpdir, num_reals, 2).build() + ensemble = make_ensemble_builder(monkeypatch, tmpdir, num_reals, 2).build() config = EvaluatorServerConfig( custom_port_range=custom_port_range, custom_host="127.0.0.1", @@ -44,11 +44,13 @@ def test_run_legacy_ensemble(tmpdir, make_ensemble_builder): @pytest.mark.timeout(60) -def test_run_and_cancel_legacy_ensemble(tmpdir, make_ensemble_builder): +def test_run_and_cancel_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypatch): num_reals = 2 custom_port_range = range(1024, 65535) with tmpdir.as_cwd(): - ensemble = make_ensemble_builder(tmpdir, num_reals, 2, job_sleep=30).build() + ensemble = make_ensemble_builder( + monkeypatch, tmpdir, num_reals, 2, job_sleep=30 + ).build() config = EvaluatorServerConfig( custom_port_range=custom_port_range, custom_host="127.0.0.1", @@ -77,11 +79,11 @@ def test_run_and_cancel_legacy_ensemble(tmpdir, make_ensemble_builder): @pytest.mark.timeout(60) -def test_run_legacy_ensemble_exception(tmpdir, make_ensemble_builder): +def test_run_legacy_ensemble_exception(tmpdir, make_ensemble_builder, monkeypatch): num_reals = 2 custom_port_range = range(1024, 65535) with tmpdir.as_cwd(): - ensemble = make_ensemble_builder(tmpdir, num_reals, 2).build() + ensemble = make_ensemble_builder(monkeypatch, tmpdir, num_reals, 2).build() config = EvaluatorServerConfig( custom_port_range=custom_port_range, custom_host="127.0.0.1", diff --git a/tests/unit_tests/job_queue/test_job_queue.py b/tests/unit_tests/job_queue/test_job_queue.py index 59a5879fca1..dba7b10f551 100644 --- a/tests/unit_tests/job_queue/test_job_queue.py +++ b/tests/unit_tests/job_queue/test_job_queue.py @@ -4,16 +4,14 @@ from dataclasses import dataclass from pathlib import Path from threading import BoundedSemaphore -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional from unittest.mock import MagicMock, patch +import ert.callbacks from ert.config import QueueSystem from ert.job_queue import Driver, JobQueue, JobQueueNode, JobStatus from ert.load_status import LoadStatus -if TYPE_CHECKING: - from ert.callbacks import Callback - def wait_for( func: Callable, target: Any = True, interval: float = 0.1, timeout: float = 30 @@ -65,11 +63,19 @@ class RunArg: def create_local_queue( + monkeypatch, executable_script: str, max_submit: int = 1, max_runtime: Optional[int] = None, - callback_timeout: Optional["Callback"] = None, + callback_timeout: Optional["Callable[[int], None]"] = None, ): + monkeypatch.setattr( + ert.job_queue.job_queue_node, "forward_model_ok", DUMMY_CONFIG["ok_callback"] + ) + monkeypatch.setattr( + JobQueueNode, "run_exit_callback", DUMMY_CONFIG["exit_callback"] + ) + driver = Driver(driver_type=QueueSystem.LOCAL) job_queue = JobQueue(driver, max_submit=max_submit) @@ -86,9 +92,8 @@ def create_local_queue( num_cpu=DUMMY_CONFIG["num_cpu"], status_file=job_queue.status_file, exit_file=job_queue.exit_file, - done_callback_function=DUMMY_CONFIG["ok_callback"], - exit_callback_function=DUMMY_CONFIG["exit_callback"], - callback_arguments=(RunArg(iens), None), + run_arg=RunArg(iens), + ensemble_config=None, max_runtime=max_runtime, callback_timeout=callback_timeout, ) @@ -107,7 +112,7 @@ def start_all(job_queue, sema_pool): def test_kill_jobs(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(NEVER_ENDING_SCRIPT) + job_queue = create_local_queue(monkeypatch, NEVER_ENDING_SCRIPT) assert job_queue.queue_size == 10 assert job_queue.is_active() @@ -138,7 +143,7 @@ def test_kill_jobs(tmpdir, monkeypatch): def test_add_jobs(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(SIMPLE_SCRIPT) + job_queue = create_local_queue(monkeypatch, SIMPLE_SCRIPT) assert job_queue.queue_size == 10 assert job_queue.is_active() @@ -158,7 +163,7 @@ def test_add_jobs(tmpdir, monkeypatch): def test_failing_jobs(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(FAILING_SCRIPT, max_submit=1) + job_queue = create_local_queue(monkeypatch, FAILING_SCRIPT, max_submit=1) assert job_queue.queue_size == 10 assert job_queue.is_active() @@ -186,11 +191,12 @@ def test_timeout_jobs(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) job_numbers = set() - def callback(runarg, _): + def callback(iens): nonlocal job_numbers - job_numbers.add(runarg.iens) + job_numbers.add(iens) job_queue = create_local_queue( + monkeypatch, NEVER_ENDING_SCRIPT, max_submit=1, max_runtime=5, @@ -225,7 +231,7 @@ def callback(runarg, _): def test_add_dispatch_info(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(SIMPLE_SCRIPT) + job_queue = create_local_queue(monkeypatch, SIMPLE_SCRIPT) ens_id = "some_id" cert = "My very nice cert" token = "my_super_secret_token" @@ -256,7 +262,7 @@ def test_add_dispatch_info(tmpdir, monkeypatch): def test_add_dispatch_info_cert_none(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(SIMPLE_SCRIPT) + job_queue = create_local_queue(monkeypatch, SIMPLE_SCRIPT) ens_id = "some_id" dispatch_url = "wss://example.org" cert = None diff --git a/tests/unit_tests/job_queue/test_job_queue_manager.py b/tests/unit_tests/job_queue/test_job_queue_manager.py index 04606f85241..fbf7c963f9b 100644 --- a/tests/unit_tests/job_queue/test_job_queue_manager.py +++ b/tests/unit_tests/job_queue/test_job_queue_manager.py @@ -7,6 +7,7 @@ import pytest +import ert.callbacks from ert.config import QueueSystem from ert.job_queue import Driver, JobQueue, JobQueueManager, JobQueueNode, JobStatus from ert.load_status import LoadStatus @@ -27,13 +28,11 @@ class Config(TypedDict): def dummy_ok_callback(runarg, path): - print(f"success {runarg}, {path}") (Path(path) / "OK").write_text("success", encoding="utf-8") return (LoadStatus.LOAD_SUCCESSFUL, "") -def dummy_exit_callback(runarg, path): - print(f"failure {runarg} {path}") +def dummy_exit_callback(self): Path("ERROR").write_text("failure", encoding="utf-8") @@ -68,8 +67,15 @@ def dummy_exit_callback(runarg, path): def create_local_queue( - executable_script: str, max_submit: int = 2, num_realizations: int = 10 + monkeypatch, executable_script: str, max_submit: int = 2, num_realizations: int = 10 ): + monkeypatch.setattr( + ert.job_queue.job_queue_node, "forward_model_ok", DUMMY_CONFIG["ok_callback"] + ) + monkeypatch.setattr( + JobQueueNode, "run_exit_callback", DUMMY_CONFIG["exit_callback"] + ) + driver = Driver(driver_type=QueueSystem.LOCAL) job_queue = JobQueue(driver, max_submit=max_submit) @@ -86,12 +92,8 @@ def create_local_queue( num_cpu=DUMMY_CONFIG["num_cpu"], status_file=job_queue.status_file, exit_file=job_queue.exit_file, - done_callback_function=DUMMY_CONFIG["ok_callback"], - exit_callback_function=DUMMY_CONFIG["exit_callback"], - callback_arguments=[ - RunArg(iens), - Path(DUMMY_CONFIG["run_path"].format(iens)).resolve(), - ], + run_arg=RunArg(iens), + ensemble_config=Path(DUMMY_CONFIG["run_path"].format(iens)).resolve(), ) job_queue.add_job(job, iens) return job_queue @@ -100,6 +102,12 @@ def create_local_queue( def test_num_cpu_submitted_correctly_lsf(tmpdir, monkeypatch): """Assert that num_cpu from the ERT configuration is passed on to the bsub command used to submit jobs to LSF""" + monkeypatch.setattr( + ert.job_queue.job_queue_node, "forward_model_ok", DUMMY_CONFIG["ok_callback"] + ) + monkeypatch.setattr( + JobQueueNode, "run_exit_callback", DUMMY_CONFIG["exit_callback"] + ) monkeypatch.chdir(tmpdir) os.putenv("PATH", os.getcwd() + ":" + os.getenv("PATH")) driver = Driver(driver_type=QueueSystem.LSF) @@ -123,12 +131,8 @@ def test_num_cpu_submitted_correctly_lsf(tmpdir, monkeypatch): num_cpu=4, status_file="STATUS", exit_file="ERROR", - done_callback_function=DUMMY_CONFIG["ok_callback"], - exit_callback_function=DUMMY_CONFIG["exit_callback"], - callback_arguments=[ - RunArg(iens=job_id), - Path(DUMMY_CONFIG["run_path"].format(job_id)).resolve(), - ], + run_arg=RunArg(iens=job_id), + ensemble_config=Path(DUMMY_CONFIG["run_path"].format(job_id)).resolve(), ) pool_sema = BoundedSemaphore(value=2) @@ -151,7 +155,7 @@ def test_num_cpu_submitted_correctly_lsf(tmpdir, monkeypatch): def test_execute_queue(tmpdir, monkeypatch): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(SIMPLE_SCRIPT) + job_queue = create_local_queue(monkeypatch, SIMPLE_SCRIPT) manager = JobQueueManager(job_queue) manager.execute_queue() @@ -166,7 +170,10 @@ def test_max_submit_reached(tmpdir, max_submit_num, monkeypatch): monkeypatch.chdir(tmpdir) num_realizations = 2 job_queue = create_local_queue( - FAILING_SCRIPT, max_submit=max_submit_num, num_realizations=num_realizations + monkeypatch, + FAILING_SCRIPT, + max_submit=max_submit_num, + num_realizations=num_realizations, ) manager = JobQueueManager(job_queue) @@ -188,7 +195,9 @@ def test_max_submit_reached(tmpdir, max_submit_num, monkeypatch): @pytest.mark.parametrize("max_submit_num", [1, 2, 3]) def test_kill_queue(tmpdir, max_submit_num, monkeypatch): monkeypatch.chdir(tmpdir) - job_queue = create_local_queue(SIMPLE_SCRIPT, max_submit=max_submit_num) + job_queue = create_local_queue( + monkeypatch, SIMPLE_SCRIPT, max_submit=max_submit_num + ) manager = JobQueueManager(job_queue) job_queue.kill_all_jobs() manager.execute_queue() diff --git a/tests/unit_tests/job_queue/test_job_queue_manager_torque.py b/tests/unit_tests/job_queue/test_job_queue_manager_torque.py index 903b6748120..a0b2fd255c2 100644 --- a/tests/unit_tests/job_queue/test_job_queue_manager_torque.py +++ b/tests/unit_tests/job_queue/test_job_queue_manager_torque.py @@ -7,6 +7,7 @@ import pytest +import ert.job_queue.job_queue_node from ert.config import QueueSystem from ert.job_queue import Driver, JobQueueNode, JobStatus from ert.load_status import LoadStatus @@ -142,7 +143,14 @@ def _deploy_script(scriptname: Path, scripttext: str): script.chmod(stat.S_IRWXU) -def _build_jobqueuenode(dummy_config: JobConfig, job_id=0): +def _build_jobqueuenode(monkeypatch, dummy_config: JobConfig, job_id=0): + monkeypatch.setattr( + ert.job_queue.job_queue_node, "forward_model_ok", dummy_config["ok_callback"] + ) + monkeypatch.setattr( + JobQueueNode, "run_exit_callback", dummy_config["exit_callback"] + ) + runpath = Path(dummy_config["run_path"].format(job_id)) runpath.mkdir() @@ -153,12 +161,8 @@ def _build_jobqueuenode(dummy_config: JobConfig, job_id=0): num_cpu=1, status_file="STATUS", exit_file="ERROR", - done_callback_function=dummy_config["ok_callback"], - exit_callback_function=dummy_config["exit_callback"], - callback_arguments=[ - RunArg(iens=job_id), - Path(dummy_config["run_path"].format(job_id)).resolve(), - ], + run_arg=RunArg(iens=job_id), + ensemble_config=Path(dummy_config["run_path"].format(job_id)).resolve(), ) return (job, runpath) @@ -179,7 +183,7 @@ def _build_jobqueuenode(dummy_config: JobConfig, job_id=0): ], ) def test_run_torque_job( - temp_working_directory, dummy_config, qsub_script, qstat_script + monkeypatch, temp_working_directory, dummy_config, qsub_script, qstat_script ): """Verify that the torque driver will succeed in submitting and monitoring torque jobs even when the Torque commands qsub and qstat @@ -197,7 +201,7 @@ def test_run_torque_job( options=[("QSTAT_CMD", temp_working_directory / "qstat")], ) - (job, runpath) = _build_jobqueuenode(dummy_config) + (job, runpath) = _build_jobqueuenode(monkeypatch, dummy_config) job.run(driver, BoundedSemaphore()) job.wait_for() @@ -214,7 +218,11 @@ def test_run_torque_job( [("", "-f 10001"), ("-x", "-f -x 10001"), ("-f", "-f -f 10001")], ) def test_that_torque_driver_passes_options_to_qstat( - temp_working_directory, dummy_config, user_qstat_option, expected_options + monkeypatch, + temp_working_directory, + dummy_config, + user_qstat_option, + expected_options, ): """The driver supports setting options to qstat, but the hard-coded -f option is always there.""" @@ -237,7 +245,7 @@ def test_that_torque_driver_passes_options_to_qstat( ], ) - job, _runpath = _build_jobqueuenode(dummy_config) + job, _runpath = _build_jobqueuenode(monkeypatch, dummy_config) job.run(driver, BoundedSemaphore()) job.wait_for() @@ -256,7 +264,12 @@ def test_that_torque_driver_passes_options_to_qstat( ], ) def test_torque_job_status_from_qstat_output( - temp_working_directory, dummy_config, job_state, exit_status, expected_status + monkeypatch, + temp_working_directory, + dummy_config, + job_state, + exit_status, + expected_status, ): _deploy_script(dummy_config["job_script"], SIMPLE_SCRIPT) _deploy_script("qsub", MOCK_QSUB) @@ -271,7 +284,7 @@ def test_torque_job_status_from_qstat_output( options=[("QSTAT_CMD", temp_working_directory / "qstat")], ) - job, _runpath = _build_jobqueuenode(dummy_config) + job, _runpath = _build_jobqueuenode(monkeypatch, dummy_config) pool_sema = BoundedSemaphore(value=2) job.run(driver, pool_sema)