Skip to content

Commit

Permalink
Fix runpaths bug in evaluate ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Sep 2, 2024
1 parent e9c4a92 commit cd017c4
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 23 deletions.
4 changes: 3 additions & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,9 @@ def _build_ensemble(
def paths(self) -> List[str]:
run_paths = []
active_realizations = np.where(self.active_realizations)[0]
for iteration in range(self.start_iteration, self._total_iterations):
for iteration in range(
self.start_iteration, self._total_iterations + self.start_iteration
):
run_paths.extend(self.run_paths.get_paths(active_realizations, iteration))
return run_paths

Expand Down
15 changes: 7 additions & 8 deletions src/ert/run_models/evaluate_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Storage
from ert.storage import Storage

from ..run_arg import create_run_arguments
from . import BaseRunModel
Expand Down Expand Up @@ -41,27 +41,26 @@ def __init__(
queue_config: QueueConfig,
status_queue: SimpleQueue[StatusEvents],
):
try:
self.ensemble = storage.get_ensemble(UUID(ensemble_id))
except KeyError as err:
raise ValueError(f"No ensemble: {ensemble_id}") from err
super().__init__(
config,
storage,
queue_config,
status_queue,
start_iteration=0,
start_iteration=self.ensemble.iteration,
total_iterations=1,
active_realizations=active_realizations,
minimum_required_realizations=minimum_required_realizations,
random_seed=random_seed,
)
self.ensemble_id = ensemble_id

def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
) -> None:
ensemble_id = self.ensemble_id
ensemble_uuid = UUID(ensemble_id)
ensemble = self._storage.get_ensemble(ensemble_uuid)
assert isinstance(ensemble, Ensemble)

ensemble = self.ensemble
experiment = ensemble.experiment
self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id))
self.set_env_key("_ERT_ENSEMBLE_ID", str(ensemble.id))
Expand Down
27 changes: 14 additions & 13 deletions src/ert/run_models/manual_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def __init__(
update_settings: UpdateSettings,
status_queue: SimpleQueue[StatusEvents],
):
try:
prior_id = UUID(ensemble_id)
prior = storage.get_ensemble(prior_id)
except (KeyError, ValueError) as err:
raise ErtRunError(
f"Prior ensemble with ID: {prior_id} does not exists"
) from err

super().__init__(
es_settings,
update_settings,
Expand All @@ -46,30 +54,23 @@ def __init__(
status_queue,
active_realizations=active_realizations,
total_iterations=1,
start_iteration=0,
start_iteration=prior.iteration,
random_seed=random_seed,
minimum_required_realizations=minimum_required_realizations,
)
self.prior_ensemble_id = ensemble_id
self.prior = prior
self.target_ensemble_format = target_ensemble
self.support_restart = False

def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
) -> None:
logger.info("Running manual update")
self.set_env_key("_ERT_EXPERIMENT_ID", str(self.prior.experiment.id))
self.set_env_key("_ERT_ENSEMBLE_ID", str(self.prior.id))

ensemble_format = self.target_ensemble_format
try:
ensemble_id = UUID(self.prior_ensemble_id)
prior = self._storage.get_ensemble(ensemble_id)
experiment = prior.experiment
self.set_env_key("_ERT_EXPERIMENT_ID", str(experiment.id))
self.set_env_key("_ERT_ENSEMBLE_ID", str(prior.id))
except (KeyError, ValueError) as err:
raise ErtRunError(
f"Prior ensemble with ID: {ensemble_id} does not exists"
) from err
self.update(prior, ensemble_format % (prior.iteration + 1))
self.update(self.prior, ensemble_format % (self.prior.iteration + 1))

@classmethod
def name(cls) -> str:
Expand Down
4 changes: 3 additions & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ def __init__(
self.restart_run = restart_run
self.prior_ensemble_id = prior_ensemble_id
start_iteration = 0
total_iterations = len(self.weights) + 1
if self.restart_run:
if not self.prior_ensemble_id:
raise ValueError("For restart run, prior ensemble must be set")
start_iteration = storage.get_ensemble(prior_ensemble_id).iteration + 1
total_iterations -= start_iteration
elif not self.experiment_name:
raise ValueError("For non-restart run, experiment name must be set")
super().__init__(
Expand All @@ -68,7 +70,7 @@ def __init__(
queue_config,
status_queue,
active_realizations=active_realizations,
total_iterations=len(self.weights) + 1,
total_iterations=total_iterations,
start_iteration=start_iteration,
random_seed=random_seed,
minimum_required_realizations=minimum_required_realizations,
Expand Down
30 changes: 30 additions & 0 deletions tests/unit_tests/run_models/test_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
SingleTestRun,
model_factory,
)
from ert.run_models.evaluate_ensemble import EvaluateEnsemble


@pytest.mark.parametrize(
Expand Down Expand Up @@ -244,3 +245,32 @@ def test_num_realizations_specified_incorrectly_raises(analysis_mode):
match="Number of active realizations must be at least 2 for an update step",
):
analysis_mode(config, MagicMock(), args, MagicMock(), MagicMock())


@pytest.mark.parametrize(
"ensemble_iteration, expected_path",
[
[0, ["realization-0/iter-0"]],
[1, ["realization-0/iter-1"]],
[2, ["realization-0/iter-2"]],
[100, ["realization-0/iter-100"]],
],
)
def test_evaluate_ensemble_paths(
tmp_path, monkeypatch, ensemble_iteration, expected_path
):
monkeypatch.chdir(tmp_path)
monkeypatch.setattr(
ert.run_models.base_run_model.BaseRunModel, "validate", MagicMock()
)
storage_mock = MagicMock()
ensemble_mock = MagicMock()
ensemble_mock.iteration = ensemble_iteration
config = ErtConfig(model_config=ModelConfig(num_realizations=2))
storage_mock.get_ensemble.return_value = ensemble_mock
model = EvaluateEnsemble(
[True], 1, str(uuid1(0)), 1234, config, storage_mock, MagicMock(), MagicMock()
)
base_path = tmp_path / "simulations"
expected_path = [str(base_path / expected) for expected in expected_path]
assert set(model.paths) == set(expected_path)

0 comments on commit cd017c4

Please sign in to comment.