Skip to content

Commit

Permalink
Refactor BaseRunModel
Browse files Browse the repository at this point in the history
  • Loading branch information
frode-aarstad committed Jan 13, 2025
1 parent 051fa31 commit 64aab1f
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 90 deletions.
86 changes: 61 additions & 25 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from contextlib import contextmanager
from pathlib import Path
from queue import SimpleQueue
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

import numpy as np

Expand All @@ -31,8 +31,11 @@
AnalysisDataEvent,
AnalysisErrorEvent,
)
from ert.config import ErtConfig, HookRuntime, QueueSystem
from ert.config import HookRuntime, QueueSystem
from ert.config.analysis_module import BaseSettings
from ert.config.forward_model_step import ForwardModelStep
from ert.config.model_config import ModelConfig
from ert.config.workflow import Workflow
from ert.enkf_main import _seed_sequence, create_run_path
from ert.ensemble_evaluator import Ensemble as EEEnsemble
from ert.ensemble_evaluator import (
Expand All @@ -58,6 +61,7 @@
from ert.mode_definitions import MODULE_MODE
from ert.runpaths import Runpaths
from ert.storage import Ensemble, Storage
from ert.substitutions import Substitutions
from ert.trace import tracer
from ert.workflow_runner import WorkflowRunner

Expand Down Expand Up @@ -131,10 +135,18 @@ def captured_logs(
class BaseRunModel(ABC):
def __init__(
self,
config: ErtConfig,
storage: Storage,
runpath_file: Path,
user_config_file: Path,
env_vars: dict[str, str],
env_pr_fm_step: dict[str, dict[str, Any]],
model_config: ModelConfig,
queue_config: QueueConfig,
forward_model_steps: list[ForwardModelStep],
status_queue: SimpleQueue[StatusEvents],
substitutions: Substitutions,
templates: list[tuple[str, str]],
hooked_workflows: defaultdict[HookRuntime, list[Workflow]],
active_realizations: list[bool],
total_iterations: int = 1,
start_iteration: int = 0,
Expand All @@ -147,27 +159,35 @@ def __init__(
the forward model and passing events back through the supplied queue.
"""
self._total_iterations = total_iterations
config.analysis_config.num_iterations = total_iterations

self.start_time: int | None = None
self.stop_time: int | None = None
self._queue_config: QueueConfig = queue_config
self._initial_realizations_mask: list[bool] = copy.copy(active_realizations)
self._completed_realizations_mask: list[bool] = []
self.support_restart: bool = True
self.ert_config = config
self._storage = storage
self._context_env: dict[str, str] = {}
self.random_seed: int = _seed_sequence(random_seed)
self.rng = np.random.default_rng(self.random_seed)
self.substitutions = config.substitutions
self._substitutions: Substitutions = substitutions
self._model_config: ModelConfig = model_config
self._runpath_file: Path = runpath_file
self._forward_model_steps: list[ForwardModelStep] = forward_model_steps
self._user_config_file: Path = user_config_file
self._templates: list[tuple[str, str]] = templates
self._hooked_workflows: defaultdict[HookRuntime, list[Workflow]] = (
hooked_workflows
)

self._env_vars: dict[str, str] = env_vars
self._env_pr_fm_step: dict[str, dict[str, Any]] = env_pr_fm_step

self.run_paths = Runpaths(
jobname_format=config.model_config.jobname_format_string,
runpath_format=config.model_config.runpath_format_string,
filename=str(config.runpath_file),
substitutions=self.substitutions,
eclbase=config.model_config.eclbase_format_string,
jobname_format=self._model_config.jobname_format_string,
runpath_format=self._model_config.runpath_format_string,
filename=str(self._runpath_file),
substitutions=self._substitutions,
eclbase=self._model_config.eclbase_format_string,
)
self._iter_snapshot: dict[int, EnsembleSnapshot] = {}
self._status_queue = status_queue
Expand Down Expand Up @@ -603,12 +623,12 @@ def _build_ensemble(
Realization(
active=run_arg.active,
iens=run_arg.iens,
fm_steps=self.ert_config.forward_model_steps,
fm_steps=self._forward_model_steps,
max_runtime=self._queue_config.max_runtime,
run_arg=run_arg,
num_cpu=self._queue_config.preferred_num_cpu,
job_script=self.ert_config.queue_config.job_script,
realization_memory=self.ert_config.queue_config.realization_memory,
job_script=self._queue_config.job_script,
realization_memory=self._queue_config.realization_memory,
)
)
return EEEnsemble(
Expand Down Expand Up @@ -676,7 +696,7 @@ def run_workflows(
storage: Storage | None = None,
ensemble: Ensemble | None = None,
) -> None:
for workflow in self.ert_config.hooked_workflows[runtime]:
for workflow in self._hooked_workflows[runtime]:
WorkflowRunner(workflow, storage, ensemble).run_blocking()

def _evaluate_and_postprocess(
Expand All @@ -688,13 +708,13 @@ def _evaluate_and_postprocess(
create_run_path(
run_args=run_args,
ensemble=ensemble,
user_config_file=self.ert_config.user_config_file,
env_vars=self.ert_config.env_vars,
env_pr_fm_step=self.ert_config.env_pr_fm_step,
forward_model_steps=self.ert_config.forward_model_steps,
substitutions=self.ert_config.substitutions,
templates=self.ert_config.ert_templates,
model_config=self.ert_config.model_config,
user_config_file=str(self._user_config_file),
env_vars=self._env_vars,
env_pr_fm_step=self._env_pr_fm_step,
forward_model_steps=self._forward_model_steps,
substitutions=self._substitutions,
templates=self._templates,
model_config=self._model_config,
runpaths=self.run_paths,
context_env=self._context_env,
)
Expand Down Expand Up @@ -735,10 +755,18 @@ def __init__(
self,
analysis_settings: BaseSettings,
update_settings: UpdateSettings,
config: ErtConfig,
storage: Storage,
runpath_file: Path,
user_config_file: Path,
env_vars: dict[str, str],
env_pr_fm_step: dict[str, dict[str, Any]],
model_config: ModelConfig,
queue_config: QueueConfig,
forward_model_steps: list[ForwardModelStep],
status_queue: SimpleQueue[StatusEvents],
substitutions: Substitutions,
templates: list[tuple[str, str]],
hooked_workflows: defaultdict[HookRuntime, list[Workflow]],
active_realizations: list[bool],
total_iterations: int,
start_iteration: int,
Expand All @@ -749,10 +777,18 @@ def __init__(
self._update_settings: UpdateSettings = update_settings

super().__init__(
config,
storage,
runpath_file,
user_config_file,
env_vars,
env_pr_fm_step,
model_config,
queue_config,
forward_model_steps,
status_queue,
substitutions,
templates,
hooked_workflows,
active_realizations=active_realizations,
total_iterations=total_iterations,
start_iteration=start_iteration,
Expand Down
24 changes: 19 additions & 5 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from queue import SimpleQueue
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -46,11 +47,24 @@ def __init__(
self.experiment: Experiment | None = None
self.ensemble: Ensemble | None = None

self._design_matrix = config.analysis_config.design_matrix
self._observations = config.observations
self._parameter_configuration = config.ensemble_config.parameter_configuration
self._response_configuration = config.ensemble_config.response_configuration

super().__init__(
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
total_iterations=1,
active_realizations=active_realizations,
random_seed=random_seed,
Expand All @@ -67,8 +81,8 @@ def run_experiment(
self.restart = restart
# If design matrix is present, we try to merge design matrix parameters
# to the experiment parameters and set new active realizations
parameters_config = self.ert_config.ensemble_config.parameter_configuration
design_matrix = self.ert_config.analysis_config.design_matrix
parameters_config = self._parameter_configuration
design_matrix = self._design_matrix
design_matrix_group = None
if design_matrix is not None:
try:
Expand All @@ -87,8 +101,8 @@ def run_experiment(
if design_matrix_group is not None
else parameters_config
),
observations=self.ert_config.observations,
responses=self.ert_config.ensemble_config.response_configuration,
observations=self._observations,
responses=self._response_configuration,
)
self.ensemble = self._storage.create_ensemble(
self.experiment,
Expand Down
21 changes: 17 additions & 4 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from queue import SimpleQueue
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -42,10 +43,18 @@ def __init__(
super().__init__(
es_settings,
update_settings,
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
active_realizations=active_realizations,
start_iteration=0,
total_iterations=2,
Expand All @@ -57,6 +66,10 @@ def __init__(

self.support_restart = False

self._parameter_configuration = config.ensemble_config.parameter_configuration
self._observations = config.observations
self._response_configuration = config.ensemble_config.response_configuration

@tracer.start_as_current_span(f"{__name__}.run_experiment")
def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
Expand All @@ -66,9 +79,9 @@ def run_experiment(
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
observations=self.ert_config.observations,
responses=self.ert_config.ensemble_config.response_configuration,
parameters=self._parameter_configuration,
observations=self._observations,
responses=self._response_configuration,
name=self.experiment_name,
)

Expand Down
12 changes: 11 additions & 1 deletion src/ert/run_models/evaluate_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID

Expand Down Expand Up @@ -47,11 +48,20 @@ def __init__(
self.ensemble = storage.get_ensemble(UUID(ensemble_id))
except KeyError as err:
raise ValueError(f"No ensemble: {ensemble_id}") from err

super().__init__(
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
start_iteration=self.ensemble.iteration,
total_iterations=1,
active_realizations=active_realizations,
Expand Down
30 changes: 21 additions & 9 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,26 @@ def __init__(

storage = open_storage(config.ens_path, mode="w")
status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue()

super().__init__(
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
config.queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
active_realizations=[], # Set dynamically in run_forward_model()
)
self.support_restart = False
self._parameter_configuration = config.ensemble_config.parameter_configuration
self._parameter_configs = config.ensemble_config.parameter_configs
self._response_configuration = config.ensemble_config.response_configuration

@classmethod
def create(
Expand Down Expand Up @@ -187,8 +199,8 @@ def run_experiment(
self._eval_server_cfg = evaluator_server_config
self._experiment = self._storage.create_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self.ert_config.ensemble_config.parameter_configuration,
responses=self.ert_config.ensemble_config.response_configuration,
parameters=self._parameter_configuration,
responses=self._response_configuration,
)

# Initialize the ropt optimizer:
Expand Down Expand Up @@ -494,7 +506,7 @@ def _check_suffix(
raise KeyError(err_msg)

for control_name, control in controls.items():
ext_config = self.ert_config.ensemble_config.parameter_configs[control_name]
ext_config = self._parameter_configs[control_name]
if isinstance(ext_config, ExtParamConfig):
if len(ext_config) != len(control.keys()):
raise KeyError(
Expand All @@ -515,7 +527,7 @@ def _get_run_args(
evaluator_context: EvaluatorContext,
batch_data: dict[int, Any],
) -> list[RunArg]:
substitutions = self.ert_config.substitutions
substitutions = self._substitutions
substitutions["<BATCH_NAME>"] = ensemble.name
self.active_realizations = [True] * len(batch_data)
for sim_id, control_idx in enumerate(batch_data.keys()):
Expand All @@ -525,11 +537,11 @@ def _get_run_args(
]
)
run_paths = Runpaths(
jobname_format=self.ert_config.model_config.jobname_format_string,
runpath_format=self.ert_config.model_config.runpath_format_string,
filename=str(self.ert_config.runpath_file),
jobname_format=self._model_config.jobname_format_string,
runpath_format=self._model_config.runpath_format_string,
filename=str(self._runpath_file),
substitutions=substitutions,
eclbase=self.ert_config.model_config.eclbase_format_string,
eclbase=self._model_config.eclbase_format_string,
)
return create_run_arguments(
run_paths,
Expand Down
Loading

0 comments on commit 64aab1f

Please sign in to comment.