Skip to content

Commit

Permalink
Add post/pre experiment simulation hooks
Browse files Browse the repository at this point in the history
* Add post/pre experiment simulation hooks
* Add docs for PRE/POST_EXPERIMENT hooks
  • Loading branch information
yngve-sk authored Jan 6, 2025
1 parent 6c9db0b commit d24f588
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 17 deletions.
26 changes: 16 additions & 10 deletions docs/ert/reference/workflows/complete_workflows.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,29 +47,35 @@ With the keyword :code:`HOOK_WORKFLOW` you can configure workflow
points during ERTs execution. Currently there are five points in ERTs
flow of execution where you can hook in a workflow:

- Before the simulations (all forward models for a realization) start using :code:`PRE_SIMULATION`,
- Before the experiment starts using :code:`PRE_EXPERIMENT`
- before the simulations (all forward models for a realization) start using :code:`PRE_SIMULATION`,
- after all the simulations have completed using :code:`POST_SIMULATION`,
- before the update step using :code:`PRE_UPDATE`
- after the update step using :code:`POST_UPDATE` and
- only before the first update using :code:`PRE_FIRST_UPDATE`.
- after the experiment has completed using :code:`POST_EXPERIMENT`

For non-iterative algorithms, :code:`PRE_FIRST_UPDATE` is equal to :code:`PRE_UPDATE`.
The :code:`POST_SIMULATION` hook is typically used to trigger QC workflows.

::

HOOK_WORKFLOW initWFLOW PRE_SIMULATION
HOOK_WORKFLOW preUpdateWFLOW PRE_UPDATE
HOOK_WORKFLOW postUpdateWFLOW POST_UPDATE
HOOK_WORKFLOW QC_WFLOW1 POST_SIMULATION
HOOK_WORKFLOW QC_WFLOW2 POST_SIMULATION

In this example the workflow :code:`initWFLOW` will run after all the
HOOK_WORKFLOW preExperimentWFLOW PRE_EXPERIMENT
HOOK_WORKFLOW initWFLOW PRE_SIMULATION
HOOK_WORKFLOW preUpdateWFLOW PRE_UPDATE
HOOK_WORKFLOW postUpdateWFLOW POST_UPDATE
HOOK_WORKFLOW QC_WFLOW1 POST_SIMULATION
HOOK_WORKFLOW QC_WFLOW2 POST_SIMULATION
HOOK_WORKFLOW postExperimentWFLOW POST_EXPERIMENT

In this example the workflow, :code:`preExperimentWFLOW` will run,
then :code:`initWFLOW` will run at the start of every iteration, when
simulation directories have been created, just before the forward
model is submitted to the queue. The workflow :code:`preUpdateWFLOW`
will be run before the update step and :code:`postUpdateWFLOW` will be
run after the update step. When all the simulations have completed the
run after the update step. At the end of each forward model run, the
two workflows :code:`QC_WFLOW1` and :code:`QC_WFLOW2` will be run.
After all iterations are complete, the :code:`postExperimentWFLOW` will
run.

Observe that the workflows being 'hooked in' with the
:code:`HOOK_WORKFLOW` must be loaded with the :code:`LOAD_WORKFLOW`
Expand Down
2 changes: 2 additions & 0 deletions src/ert/config/parsing/hook_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ class HookRuntime(StrEnum):
PRE_UPDATE = "PRE_UPDATE"
POST_UPDATE = "POST_UPDATE"
PRE_FIRST_UPDATE = "PRE_FIRST_UPDATE"
PRE_EXPERIMENT = "PRE_EXPERIMENT"
POST_EXPERIMENT = "POST_EXPERIMENT"
5 changes: 4 additions & 1 deletion src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,7 +670,10 @@ def validate_successful_realizations_count(self) -> None:

@tracer.start_as_current_span(f"{__name__}.run_workflows")
def run_workflows(
self, runtime: HookRuntime, storage: Storage, ensemble: Ensemble
self,
runtime: HookRuntime,
storage: Storage | None = None,
ensemble: Ensemble | None = None,
) -> None:
for workflow in self.ert_config.hooked_workflows[runtime]:
WorkflowRunner(workflow, storage, ensemble).run_blocking()
Expand Down
4 changes: 3 additions & 1 deletion src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from ert.config import ConfigValidationError
from ert.config import ConfigValidationError, HookRuntime
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage
Expand Down Expand Up @@ -79,6 +79,7 @@ def run_experiment(
raise ErtRunError(str(exc)) from exc

if not restart:
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
self.experiment = self._storage.create_experiment(
name=self.experiment_name,
parameters=(
Expand Down Expand Up @@ -128,6 +129,7 @@ def run_experiment(
self.ensemble,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT)

@classmethod
def name(cls) -> str:
Expand Down
4 changes: 3 additions & 1 deletion src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from ert.config import ErtConfig
from ert.config import ErtConfig, HookRuntime
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Storage
Expand Down Expand Up @@ -63,6 +63,7 @@ def run_experiment(
) -> None:
self.log_at_startup()
self.restart = restart
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
Expand Down Expand Up @@ -108,6 +109,7 @@ def run_experiment(
posterior,
evaluator_server_config,
)
self.run_workflows(HookRuntime.POST_EXPERIMENT)

@classmethod
def name(cls) -> str:
Expand Down
4 changes: 4 additions & 0 deletions src/ert/run_models/iterated_ensemble_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def run_experiment(
) -> None:
self.log_at_startup()
self.restart = restart
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
target_ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
Expand Down Expand Up @@ -150,6 +151,7 @@ def run_experiment(
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

self._evaluate_and_postprocess(
prior_args,
prior,
Expand Down Expand Up @@ -218,6 +220,8 @@ def run_experiment(
)
prior = posterior

self.run_workflows(HookRuntime.POST_EXPERIMENT)

@classmethod
def name(cls) -> str:
return "Iterated ensemble smoother"
Expand Down
6 changes: 5 additions & 1 deletion src/ert/run_models/multiple_data_assimilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import numpy as np

from ert.config import ErtConfig
from ert.config import ErtConfig, HookRuntime
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Storage
Expand Down Expand Up @@ -106,6 +106,7 @@ def run_experiment(
f"Prior ensemble with ID: {id} does not exists"
) from err
else:
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
sim_args = {"weights": self._relative_weights}
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
Expand All @@ -128,6 +129,7 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=prior,
)

sample_prior(
prior,
np.where(self.active_realizations)[0],
Expand Down Expand Up @@ -159,6 +161,8 @@ def run_experiment(
)
prior = posterior

self.run_workflows(HookRuntime.POST_EXPERIMENT)

@staticmethod
def parse_weights(weights: str) -> list[float]:
"""Parse weights string and scale weights such that their reciprocals sum
Expand Down
69 changes: 69 additions & 0 deletions tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,75 @@ def test_that_stop_on_fail_workflow_jobs_stop_ert(
run_cli(TEST_RUN_MODE, "--disable-monitor", "poly.ert")


@pytest.mark.usefixtures("copy_poly_case")
def test_that_pre_post_experiment_hook_works(monkeypatch, capsys):
monkeypatch.setattr(_ert.threading, "_can_raise", False)

# The executable
with open("hello_post_exp.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo "just sending regards"
""")
)
os.chmod("hello_post_exp.sh", 0o755)

# The workflow job
with open("SAY_HELLO_POST_EXP", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE hello_post_exp.sh
""")

# The workflow
with open("SAY_HELLO_POST_EXP.wf", "w", encoding="utf-8") as s:
s.write("""dump_final_ensemble_id""")

# The executable
with open("hello_pre_exp.sh", "w", encoding="utf-8") as f:
f.write(
dedent("""#!/bin/bash
echo "first"
""")
)
os.chmod("hello_pre_exp.sh", 0o755)

# The workflow job
with open("SAY_HELLO_PRE_EXP", "w", encoding="utf-8") as s:
s.write("""
INTERNAL False
EXECUTABLE hello_pre_exp.sh
""")

# The workflow
with open("SAY_HELLO_PRE_EXP.wf", "w", encoding="utf-8") as s:
s.write("""dump_first_ensemble_id""")

with open("poly.ert", mode="a", encoding="utf-8") as fh:
fh.write(
dedent(
"""
NUM_REALIZATIONS 2
LOAD_WORKFLOW_JOB SAY_HELLO_POST_EXP dump_final_ensemble_id
LOAD_WORKFLOW SAY_HELLO_POST_EXP.wf POST_EXPERIMENT_DUMP
HOOK_WORKFLOW POST_EXPERIMENT_DUMP POST_EXPERIMENT
LOAD_WORKFLOW_JOB SAY_HELLO_PRE_EXP dump_first_ensemble_id
LOAD_WORKFLOW SAY_HELLO_PRE_EXP.wf PRE_EXPERIMENT_DUMP
HOOK_WORKFLOW PRE_EXPERIMENT_DUMP PRE_EXPERIMENT
"""
)
)

for mode in [ITERATIVE_ENSEMBLE_SMOOTHER_MODE, ES_MDA_MODE, ENSEMBLE_SMOOTHER_MODE]:
run_cli(mode, "--disable-monitor", "poly.ert")

captured = capsys.readouterr()
assert "first" in captured.out
assert "just sending regards" in captured.out


@pytest.fixture(name="mock_cli_run")
def fixture_mock_cli_run(monkeypatch):
end_event = Mock()
Expand Down
12 changes: 9 additions & 3 deletions tests/ert/unit_tests/cli/test_model_hook_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ def test_hook_call_order_ensemble_smoother(monkeypatch):
test_class.run_experiment(MagicMock())

expected_calls = [
call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER
call(HookRuntime.PRE_EXPERIMENT),
*[call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER],
call(HookRuntime.POST_EXPERIMENT),
]
assert run_wfs_mock.mock_calls == expected_calls

Expand Down Expand Up @@ -95,7 +97,9 @@ def test_hook_call_order_es_mda(monkeypatch):
test_class.run_experiment(MagicMock())

expected_calls = [
call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER
call(HookRuntime.PRE_EXPERIMENT),
*[call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER],
call(HookRuntime.POST_EXPERIMENT),
]
assert run_wfs_mock.mock_calls == expected_calls

Expand Down Expand Up @@ -130,6 +134,8 @@ def test_hook_call_order_iterative_ensemble_smoother(monkeypatch):
test_class.run_experiment(MagicMock())

expected_calls = [
call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER
call(HookRuntime.PRE_EXPERIMENT),
*[call(expected_call, ANY, ANY) for expected_call in EXPECTED_CALL_ORDER],
call(HookRuntime.POST_EXPERIMENT),
]
assert run_wfs_mock.mock_calls == expected_calls

0 comments on commit d24f588

Please sign in to comment.