From cd017c4874d79fb95a3bd1b5ef7c8c76527abad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=98yvind=20Eide?= Date: Thu, 22 Aug 2024 09:15:25 +0200 Subject: [PATCH] Fix runpaths bug in evaluate ensemble --- src/ert/run_models/base_run_model.py | 4 ++- src/ert/run_models/evaluate_ensemble.py | 15 +++++----- src/ert/run_models/manual_update.py | 27 +++++++++-------- .../run_models/multiple_data_assimilation.py | 4 ++- .../run_models/test_model_factory.py | 30 +++++++++++++++++++ 5 files changed, 57 insertions(+), 23 deletions(-) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index fb5080a78f6..79b496548ad 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -557,7 +557,9 @@ def _build_ensemble( def paths(self) -> List[str]: run_paths = [] active_realizations = np.where(self.active_realizations)[0] - for iteration in range(self.start_iteration, self._total_iterations): + for iteration in range( + self.start_iteration, self._total_iterations + self.start_iteration + ): run_paths.extend(self.run_paths.get_paths(active_realizations, iteration)) return run_paths diff --git a/src/ert/run_models/evaluate_ensemble.py b/src/ert/run_models/evaluate_ensemble.py index 90c4a7d340a..84b2d7025a2 100644 --- a/src/ert/run_models/evaluate_ensemble.py +++ b/src/ert/run_models/evaluate_ensemble.py @@ -6,7 +6,7 @@ import numpy as np from ert.ensemble_evaluator import EvaluatorServerConfig -from ert.storage import Ensemble, Storage +from ert.storage import Storage from ..run_arg import create_run_arguments from . import BaseRunModel @@ -41,27 +41,26 @@ def __init__( queue_config: QueueConfig, status_queue: SimpleQueue[StatusEvents], ): + try: + self.ensemble = storage.get_ensemble(UUID(ensemble_id)) + except KeyError as err: + raise ValueError(f"No ensemble: {ensemble_id}") from err super().__init__( config, storage, queue_config, status_queue, - start_iteration=0, + start_iteration=self.ensemble.iteration, total_iterations=1, active_realizations=active_realizations, minimum_required_realizations=minimum_required_realizations, random_seed=random_seed, ) - self.ensemble_id = ensemble_id def run_experiment( self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False ) -> None: - ensemble_id = self.ensemble_id - ensemble_uuid = UUID(ensemble_id) - ensemble = self._storage.get_ensemble(ensemble_uuid) - assert isinstance(ensemble, Ensemble) - + ensemble = self.ensemble experiment = ensemble.experiment self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id)) self.set_env_key("_ERT_ENSEMBLE_ID", str(ensemble.id)) diff --git a/src/ert/run_models/manual_update.py b/src/ert/run_models/manual_update.py index c5f80a4bd81..0deeeb3557a 100644 --- a/src/ert/run_models/manual_update.py +++ b/src/ert/run_models/manual_update.py @@ -37,6 +37,14 @@ def __init__( update_settings: UpdateSettings, status_queue: SimpleQueue[StatusEvents], ): + try: + prior_id = UUID(ensemble_id) + prior = storage.get_ensemble(prior_id) + except (KeyError, ValueError) as err: + raise ErtRunError( + f"Prior ensemble with ID: {prior_id} does not exists" + ) from err + super().__init__( es_settings, update_settings, @@ -46,11 +54,11 @@ def __init__( status_queue, active_realizations=active_realizations, total_iterations=1, - start_iteration=0, + start_iteration=prior.iteration, random_seed=random_seed, minimum_required_realizations=minimum_required_realizations, ) - self.prior_ensemble_id = ensemble_id + self.prior = prior self.target_ensemble_format = target_ensemble self.support_restart = False @@ -58,18 +66,11 @@ def run_experiment( self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False ) -> None: logger.info("Running manual update") + self.set_env_key("_ERT_EXPERIMENT_ID", str(self.prior.experiment.id)) + self.set_env_key("_ERT_ENSEMBLE_ID", str(self.prior.id)) + ensemble_format = self.target_ensemble_format - try: - ensemble_id = UUID(self.prior_ensemble_id) - prior = self._storage.get_ensemble(ensemble_id) - experiment = prior.experiment - self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id)) - self.set_env_key("_ERT_ENSEMBLE_ID", str(prior.id)) - except (KeyError, ValueError) as err: - raise ErtRunError( - f"Prior ensemble with ID: {ensemble_id} does not exists" - ) from err - self.update(prior, ensemble_format % (prior.iteration + 1)) + self.update(self.prior, ensemble_format % (self.prior.iteration + 1)) @classmethod def name(cls) -> str: diff --git a/src/ert/run_models/multiple_data_assimilation.py b/src/ert/run_models/multiple_data_assimilation.py index df2558b8c89..04f8f3d683e 100644 --- a/src/ert/run_models/multiple_data_assimilation.py +++ b/src/ert/run_models/multiple_data_assimilation.py @@ -54,10 +54,12 @@ def __init__( self.restart_run = restart_run self.prior_ensemble_id = prior_ensemble_id start_iteration = 0 + total_iterations = len(self.weights) + 1 if self.restart_run: if not self.prior_ensemble_id: raise ValueError("For restart run, prior ensemble must be set") start_iteration = storage.get_ensemble(prior_ensemble_id).iteration + 1 + total_iterations -= start_iteration elif not self.experiment_name: raise ValueError("For non-restart run, experiment name must be set") super().__init__( @@ -68,7 +70,7 @@ def __init__( queue_config, status_queue, active_realizations=active_realizations, - total_iterations=len(self.weights) + 1, + total_iterations=total_iterations, start_iteration=start_iteration, random_seed=random_seed, minimum_required_realizations=minimum_required_realizations, diff --git a/tests/unit_tests/run_models/test_model_factory.py b/tests/unit_tests/run_models/test_model_factory.py index b2ee9046246..1b26d253aa9 100644 --- a/tests/unit_tests/run_models/test_model_factory.py +++ b/tests/unit_tests/run_models/test_model_factory.py @@ -15,6 +15,7 @@ SingleTestRun, model_factory, ) +from ert.run_models.evaluate_ensemble import EvaluateEnsemble @pytest.mark.parametrize( @@ -244,3 +245,32 @@ def test_num_realizations_specified_incorrectly_raises(analysis_mode): match="Number of active realizations must be at least 2 for an update step", ): analysis_mode(config, MagicMock(), args, MagicMock(), MagicMock()) + + +@pytest.mark.parametrize( + "ensemble_iteration, expected_path", + [ + [0, ["realization-0/iter-0"]], + [1, ["realization-0/iter-1"]], + [2, ["realization-0/iter-2"]], + [100, ["realization-0/iter-100"]], + ], +) +def test_evaluate_ensemble_paths( + tmp_path, monkeypatch, ensemble_iteration, expected_path +): + monkeypatch.chdir(tmp_path) + monkeypatch.setattr( + ert.run_models.base_run_model.BaseRunModel, "validate", MagicMock() + ) + storage_mock = MagicMock() + ensemble_mock = MagicMock() + ensemble_mock.iteration = ensemble_iteration + config = ErtConfig(model_config=ModelConfig(num_realizations=2)) + storage_mock.get_ensemble.return_value = ensemble_mock + model = EvaluateEnsemble( + [True], 1, str(uuid1(0)), 1234, config, storage_mock, MagicMock(), MagicMock() + ) + base_path = tmp_path / "simulations" + expected_path = [str(base_path / expected) for expected in expected_path] + assert set(model.paths) == set(expected_path)