def activate_script() -> str:
    """Return the default shell command that activates the current virtualenv.

    The command is derived from the ``VIRTUAL_ENV`` environment variable and
    is meant to be placed on its own line in a bash script.

    :return: ``"source <venv>/bin/activate"`` with the path shell-quoted, or
        an empty string when ``VIRTUAL_ENV`` is unset or empty.
    """
    # Local import: keeps this function self-contained without touching the
    # module-level import block.
    import shlex

    venv = os.environ.get("VIRTUAL_ENV")
    if not venv:
        return ""
    # Quote the path so that a virtualenv living in a directory with spaces
    # or shell metacharacters still yields a valid (and safe) bash line.
    # For ordinary paths shlex.quote() returns the string unchanged.
    return f"source {shlex.quote(f'{venv}/bin/activate')}"
def create_submit_script(
    runpath: Path, executable: str, args: tuple[str, ...], activate_script: str
) -> str:
    """Build the bash script that a cluster driver submits for one job.

    The generated script:

    1. changes directory to ``runpath``,
    2. runs ``activate_script`` (inserted verbatim on its own line; an empty
       string simply leaves a blank line),
    3. replaces the shell with ``executable`` via ``exec -a`` so that the
       process name (argv[0]) equals ``executable``.

    :param runpath: Directory the job is started from.
    :param executable: Program to run; shell-quoted in the script.
    :param args: Arguments passed to ``executable``; each one is shell-quoted.
    :param activate_script: Shell snippet run before the executable. It is
        NOT quoted because it is a command line, so it must come from a
        trusted source.
    :return: The full script text, terminated by a newline.
    """
    # Quote once and reuse: previously only the ``-a`` alias was quoted while
    # the command word itself was left raw, so an executable path containing
    # spaces or shell metacharacters was word-split by bash.
    quoted_executable = shlex.quote(executable)
    return (
        "#!/usr/bin/env bash\n"
        f"cd {shlex.quote(str(runpath))}\n"
        f"{activate_script}\n"
        f"exec -a {quoted_executable} {quoted_executable} {shlex.join(args)}\n"
    )
__init__(self, activate_script: str = "") -> None: self._event_queue: Optional[asyncio.Queue[Event]] = None self._job_error_message_by_iens: Dict[int, str] = {} + self.activate_script = activate_script @property def event_queue(self) -> asyncio.Queue[Event]: diff --git a/src/ert/scheduler/local_driver.py b/src/ert/scheduler/local_driver.py index 49cc9c5f0bc..b17652c8819 100644 --- a/src/ert/scheduler/local_driver.py +++ b/src/ert/scheduler/local_driver.py @@ -34,6 +34,7 @@ async def submit( runpath: Optional[Path] = None, num_cpu: Optional[int] = 1, realization_memory: Optional[int] = 0, + activate_script: str = "", ) -> None: self._tasks[iens] = asyncio.create_task(self._run(iens, executable, *args)) with suppress(KeyError): diff --git a/src/ert/scheduler/lsf_driver.py b/src/ert/scheduler/lsf_driver.py index 1bd04bee2b9..40ad0dd03d8 100644 --- a/src/ert/scheduler/lsf_driver.py +++ b/src/ert/scheduler/lsf_driver.py @@ -27,7 +27,7 @@ get_args, ) -from .driver import SIGNAL_OFFSET, Driver, FailedSubmit +from .driver import SIGNAL_OFFSET, Driver, FailedSubmit, create_submit_script from .event import Event, FinishedEvent, StartedEvent _POLL_PERIOD = 2.0 # seconds @@ -264,8 +264,9 @@ def __init__( bjobs_cmd: Optional[str] = None, bkill_cmd: Optional[str] = None, bhist_cmd: Optional[str] = None, + activate_script: str = "", ) -> None: - super().__init__() + super().__init__(activate_script) self._queue_name = queue_name self._project_code = project_code self._resource_requirement = resource_requirement @@ -309,12 +310,7 @@ async def submit( arg_queue_name = ["-q", self._queue_name] if self._queue_name else [] arg_project_code = ["-P", self._project_code] if self._project_code else [] - - script = ( - "#!/usr/bin/env bash\n" - f"cd {shlex.quote(str(runpath))}\n" - f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n" - ) + script = create_submit_script(runpath, executable, args, self.activate_script) script_path: Optional[Path] = None try: with 
NamedTemporaryFile( diff --git a/src/ert/scheduler/openpbs_driver.py b/src/ert/scheduler/openpbs_driver.py index dbf75ae3199..2b17e066fef 100644 --- a/src/ert/scheduler/openpbs_driver.py +++ b/src/ert/scheduler/openpbs_driver.py @@ -23,7 +23,7 @@ get_type_hints, ) -from .driver import Driver, FailedSubmit +from .driver import Driver, FailedSubmit, create_submit_script from .event import Event, FinishedEvent, StartedEvent logger = logging.getLogger(__name__) @@ -145,8 +145,9 @@ def __init__( qsub_cmd: Optional[str] = None, qstat_cmd: Optional[str] = None, qdel_cmd: Optional[str] = None, + activate_script: str = "", ) -> None: - super().__init__() + super().__init__(activate_script) self._queue_name = queue_name self._project_code = project_code @@ -241,11 +242,7 @@ async def submit( [] if self._keep_qsub_output else ["-o", "/dev/null", "-e", "/dev/null"] ) - script = ( - "#!/usr/bin/env bash\n" - f"cd {shlex.quote(str(runpath))}\n" - f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n" - ) + script = create_submit_script(runpath, executable, args, self.activate_script) name_prefix = self._job_prefix or "" qsub_with_args: List[str] = [ str(self._qsub_cmd), diff --git a/src/ert/scheduler/slurm_driver.py b/src/ert/scheduler/slurm_driver.py index 2637a5cc919..1e0c06d4791 100644 --- a/src/ert/scheduler/slurm_driver.py +++ b/src/ert/scheduler/slurm_driver.py @@ -17,7 +17,7 @@ Tuple, ) -from .driver import SIGNAL_OFFSET, Driver, FailedSubmit +from .driver import SIGNAL_OFFSET, Driver, FailedSubmit, create_submit_script from .event import Event, FinishedEvent, StartedEvent SLURM_FAILED_EXIT_CODE_FETCH = SIGNAL_OFFSET + 66 @@ -77,6 +77,7 @@ def __init__( max_runtime: Optional[float] = None, squeue_timeout: float = 2, project_code: Optional[str] = None, + activate_script: str = "", ) -> None: """ The arguments "memory" and "realization_memory" are currently both @@ -90,7 +91,7 @@ def __init__( zero "realization memory" is the default and means no intended 
memory allocation. """ - super().__init__() + super().__init__(activate_script) self._submit_locks: dict[int, asyncio.Lock] = {} self._iens2jobid: dict[int, str] = {} self._jobs: dict[str, JobData] = {} @@ -181,11 +182,7 @@ async def submit( if runpath is None: runpath = Path.cwd() - script = ( - "#!/usr/bin/env bash\n" - f"cd {shlex.quote(str(runpath))}\n" - f"exec -a {shlex.quote(executable)} {executable} {shlex.join(args)}\n" - ) + script = create_submit_script(runpath, executable, args, self.activate_script) script_path: Optional[Path] = None try: with NamedTemporaryFile( diff --git a/src/everest/detached/__init__.py b/src/everest/detached/__init__.py index ff8c1be0de7..c383a6e946c 100644 --- a/src/everest/detached/__init__.py +++ b/src/everest/detached/__init__.py @@ -1,3 +1,4 @@ +import asyncio import importlib import json import logging @@ -21,6 +22,7 @@ SlurmQueueOptions, TorqueQueueOptions, ) +from ert.plugins import ErtPluginManager from ert.scheduler import create_driver from ert.scheduler.driver import Driver, FailedSubmit from ert.scheduler.event import StartedEvent @@ -57,12 +59,15 @@ async def start_server(config: EverestConfig, debug: bool = False) -> Driver: args = ["--config-file", str(config.config_path)] if debug: args.append("--debug") + poll_task = asyncio.create_task(driver.poll(), name="poll_task") await driver.submit(0, "everserver", *args) except FailedSubmit as err: raise ValueError(f"Failed to submit Everserver with error: {err}") from err status = await driver.event_queue.get() if not isinstance(status, StartedEvent): + poll_task.cancel() raise ValueError(f"Everserver not started as expected, got status: {status}") + poll_task.cancel() return driver @@ -279,22 +284,24 @@ def get_server_queue_options( simulator: Optional[SimulatorConfig], server: Optional[ServerConfig], ) -> QueueOptions: + activate_script = ErtPluginManager().activate_script() queue_system = _find_res_queue_system(simulator, server) ever_queue_config = server if server 
is not None else simulator - if queue_system == QueueSystem.LSF: queue = LsfQueueOptions( + activate_script=activate_script, lsf_queue=ever_queue_config.name, lsf_resource=ever_queue_config.options, ) elif queue_system == QueueSystem.SLURM: queue = SlurmQueueOptions( + activate_script=activate_script, exclude_host=ever_queue_config.exclude_host, include_host=ever_queue_config.include_host, partition=ever_queue_config.name, ) elif queue_system == QueueSystem.TORQUE: - queue = TorqueQueueOptions() + queue = TorqueQueueOptions(activate_script=activate_script) elif queue_system == QueueSystem.LOCAL: queue = LocalQueueOptions() else: diff --git a/tests/ert/unit_tests/config/test_queue_config.py b/tests/ert/unit_tests/config/test_queue_config.py index 73e0ad7d680..bfa26641652 100644 --- a/tests/ert/unit_tests/config/test_queue_config.py +++ b/tests/ert/unit_tests/config/test_queue_config.py @@ -15,6 +15,7 @@ from ert.config.queue_config import ( LocalQueueOptions, LsfQueueOptions, + QueueOptions, SlurmQueueOptions, TorqueQueueOptions, ) @@ -509,3 +510,15 @@ def test_driver_initialization_from_defaults(queue_system): LocalDriver(**LocalQueueOptions().driver_options) if queue_system == QueueSystem.SLURM: SlurmDriver(**SlurmQueueOptions().driver_options) + + +@pytest.mark.parametrize( + "venv, expected", [("my_env", "source my_env/bin/activate"), (None, "")] +) +def test_default_activate_script_generation(expected, monkeypatch, venv): + if venv: + monkeypatch.setenv("VIRTUAL_ENV", venv) + else: + monkeypatch.delenv("VIRTUAL_ENV", raising=False) + options = QueueOptions(name="local") + assert options.activate_script == expected diff --git a/tests/ert/unit_tests/plugins/test_plugin_manager.py b/tests/ert/unit_tests/plugins/test_plugin_manager.py index 49d8c8fc037..97f479c252c 100644 --- a/tests/ert/unit_tests/plugins/test_plugin_manager.py +++ b/tests/ert/unit_tests/plugins/test_plugin_manager.py @@ -1,14 +1,15 @@ import json import logging import tempfile -from unittest.mock 
import Mock +from functools import partial +from unittest.mock import Mock, patch import pytest from opentelemetry.sdk.trace import TracerProvider import ert.plugins.hook_implementations -from ert import plugin -from ert.plugins import ErtPluginManager +from ert.config import ErtConfig +from ert.plugins import ErtPluginManager, plugin from tests.ert.unit_tests.plugins import dummy_plugins from tests.ert.unit_tests.plugins.dummy_plugins import ( DummyFMStep, @@ -279,3 +280,44 @@ def test_that_forward_model_step_is_registered(tmpdir): with tmpdir.as_cwd(): pm = ErtPluginManager(plugins=[dummy_plugins]) assert pm.forward_model_steps == [DummyFMStep] + + +class ActivatePlugin: + @plugin(name="first") + def activate_script(self): + return "source something" + + +class AnotherActivatePlugin: + @plugin(name="second") + def activate_script(self): + return "Something" + + +class EmptyActivatePlugin: + @plugin(name="empty") + def activate_script(self): + return None + + +@pytest.mark.parametrize( + "plugins", [[ActivatePlugin()], [ActivatePlugin(), EmptyActivatePlugin()]] +) +def test_activate_script_hook(plugins): + pm = ErtPluginManager(plugins=plugins) + assert pm.activate_script() == "source something" + + +def test_multiple_activate_script_hook(): + pm = ErtPluginManager(plugins=[ActivatePlugin(), AnotherActivatePlugin()]) + with pytest.raises(ValueError, match="one activate script is allowed"): + pm.activate_script() + + +def test_activate_script_plugin_integration(): + patched = partial( + ert.config.ert_config.ErtPluginManager, plugins=[ActivatePlugin()] + ) + with patch("ert.config.ert_config.ErtPluginManager", patched): + config = ErtConfig.with_plugins().from_file_contents("NUM_REALIZATIONS 1\n") + assert config.queue_config.queue_options.activate_script == "source something" diff --git a/tests/everest/test_res_initialization.py b/tests/everest/test_res_initialization.py index 7f825723232..214e9f19418 100644 --- a/tests/everest/test_res_initialization.py +++ 
b/tests/everest/test_res_initialization.py @@ -279,7 +279,9 @@ def test_snake_everest_to_ert_torque(copy_test_data_to_tmp): qc = ert_config.queue_config qo = qc.queue_options assert qc.queue_system == "TORQUE" - assert {k: v for k, v in qo.driver_options.items() if v is not None} == { + driver_options = qo.driver_options + driver_options.pop("activate_script") + assert {k: v for k, v in driver_options.items() if v is not None} == { "project_code": "snake_oil_pc", "qsub_cmd": "qsub", "qstat_cmd": "qstat",