Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor BaseRunModel #9676

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 61 additions & 25 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from contextlib import contextmanager
from pathlib import Path
from queue import SimpleQueue
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, Any, cast

import numpy as np

Expand All @@ -31,8 +31,11 @@
AnalysisDataEvent,
AnalysisErrorEvent,
)
from ert.config import ErtConfig, HookRuntime, QueueSystem
from ert.config import HookRuntime, QueueSystem
from ert.config.analysis_module import BaseSettings
from ert.config.forward_model_step import ForwardModelStep
from ert.config.model_config import ModelConfig
from ert.config.workflow import Workflow
from ert.enkf_main import _seed_sequence, create_run_path
from ert.ensemble_evaluator import Ensemble as EEEnsemble
from ert.ensemble_evaluator import (
Expand All @@ -58,6 +61,7 @@
from ert.mode_definitions import MODULE_MODE
from ert.runpaths import Runpaths
from ert.storage import Ensemble, Storage
from ert.substitutions import Substitutions
from ert.trace import tracer
from ert.workflow_runner import WorkflowRunner

Expand Down Expand Up @@ -131,10 +135,18 @@ def captured_logs(
class BaseRunModel(ABC):
def __init__(
self,
config: ErtConfig,
storage: Storage,
runpath_file: Path,
user_config_file: Path,
env_vars: dict[str, str],
env_pr_fm_step: dict[str, dict[str, Any]],
model_config: ModelConfig,
queue_config: QueueConfig,
forward_model_steps: list[ForwardModelStep],
status_queue: SimpleQueue[StatusEvents],
substitutions: Substitutions,
templates: list[tuple[str, str]],
hooked_workflows: defaultdict[HookRuntime, list[Workflow]],
active_realizations: list[bool],
total_iterations: int = 1,
start_iteration: int = 0,
Expand All @@ -147,27 +159,35 @@ def __init__(
the forward model and passing events back through the supplied queue.
"""
self._total_iterations = total_iterations
config.analysis_config.num_iterations = total_iterations

self.start_time: int | None = None
self.stop_time: int | None = None
self._queue_config: QueueConfig = queue_config
self._initial_realizations_mask: list[bool] = copy.copy(active_realizations)
self._completed_realizations_mask: list[bool] = []
self.support_restart: bool = True
self.ert_config = config
self._storage = storage
self._context_env: dict[str, str] = {}
self.random_seed: int = _seed_sequence(random_seed)
self.rng = np.random.default_rng(self.random_seed)
self.substitutions = config.substitutions
self._substitutions: Substitutions = substitutions
self._model_config: ModelConfig = model_config
self._runpath_file: Path = runpath_file
self._forward_model_steps: list[ForwardModelStep] = forward_model_steps
self._user_config_file: Path = user_config_file
self._templates: list[tuple[str, str]] = templates
self._hooked_workflows: defaultdict[HookRuntime, list[Workflow]] = (
hooked_workflows
)

self._env_vars: dict[str, str] = env_vars
self._env_pr_fm_step: dict[str, dict[str, Any]] = env_pr_fm_step

self.run_paths = Runpaths(
jobname_format=config.model_config.jobname_format_string,
runpath_format=config.model_config.runpath_format_string,
filename=str(config.runpath_file),
substitutions=self.substitutions,
eclbase=config.model_config.eclbase_format_string,
jobname_format=self._model_config.jobname_format_string,
runpath_format=self._model_config.runpath_format_string,
filename=str(self._runpath_file),
substitutions=self._substitutions,
eclbase=self._model_config.eclbase_format_string,
)
self._iter_snapshot: dict[int, EnsembleSnapshot] = {}
self._status_queue = status_queue
Expand Down Expand Up @@ -603,12 +623,12 @@ def _build_ensemble(
Realization(
active=run_arg.active,
iens=run_arg.iens,
fm_steps=self.ert_config.forward_model_steps,
fm_steps=self._forward_model_steps,
max_runtime=self._queue_config.max_runtime,
run_arg=run_arg,
num_cpu=self._queue_config.preferred_num_cpu,
job_script=self.ert_config.queue_config.job_script,
realization_memory=self.ert_config.queue_config.realization_memory,
job_script=self._queue_config.job_script,
realization_memory=self._queue_config.realization_memory,
)
)
return EEEnsemble(
Expand Down Expand Up @@ -676,7 +696,7 @@ def run_workflows(
storage: Storage | None = None,
ensemble: Ensemble | None = None,
) -> None:
for workflow in self.ert_config.hooked_workflows[runtime]:
for workflow in self._hooked_workflows[runtime]:
WorkflowRunner(workflow, storage, ensemble).run_blocking()

def _evaluate_and_postprocess(
Expand All @@ -688,13 +708,13 @@ def _evaluate_and_postprocess(
create_run_path(
run_args=run_args,
ensemble=ensemble,
user_config_file=self.ert_config.user_config_file,
env_vars=self.ert_config.env_vars,
env_pr_fm_step=self.ert_config.env_pr_fm_step,
forward_model_steps=self.ert_config.forward_model_steps,
substitutions=self.ert_config.substitutions,
templates=self.ert_config.ert_templates,
model_config=self.ert_config.model_config,
user_config_file=str(self._user_config_file),
env_vars=self._env_vars,
env_pr_fm_step=self._env_pr_fm_step,
forward_model_steps=self._forward_model_steps,
substitutions=self._substitutions,
templates=self._templates,
model_config=self._model_config,
runpaths=self.run_paths,
context_env=self._context_env,
)
Expand Down Expand Up @@ -735,10 +755,18 @@ def __init__(
self,
analysis_settings: BaseSettings,
update_settings: UpdateSettings,
config: ErtConfig,
storage: Storage,
runpath_file: Path,
user_config_file: Path,
env_vars: dict[str, str],
env_pr_fm_step: dict[str, dict[str, Any]],
model_config: ModelConfig,
queue_config: QueueConfig,
forward_model_steps: list[ForwardModelStep],
status_queue: SimpleQueue[StatusEvents],
substitutions: Substitutions,
templates: list[tuple[str, str]],
hooked_workflows: defaultdict[HookRuntime, list[Workflow]],
active_realizations: list[bool],
total_iterations: int,
start_iteration: int,
Expand All @@ -749,10 +777,18 @@ def __init__(
self._update_settings: UpdateSettings = update_settings

super().__init__(
config,
storage,
runpath_file,
user_config_file,
env_vars,
env_pr_fm_step,
model_config,
queue_config,
forward_model_steps,
status_queue,
substitutions,
templates,
hooked_workflows,
active_realizations=active_realizations,
total_iterations=total_iterations,
start_iteration=start_iteration,
Expand Down
24 changes: 19 additions & 5 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from queue import SimpleQueue
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -46,11 +47,24 @@ def __init__(
self.experiment: Experiment | None = None
self.ensemble: Ensemble | None = None

self._design_matrix = config.analysis_config.design_matrix
self._observations = config.observations
self._parameter_configuration = config.ensemble_config.parameter_configuration
self._response_configuration = config.ensemble_config.response_configuration

super().__init__(
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
total_iterations=1,
active_realizations=active_realizations,
random_seed=random_seed,
Expand All @@ -67,8 +81,8 @@ def run_experiment(
self.restart = restart
# If design matrix is present, we try to merge design matrix parameters
# to the experiment parameters and set new active realizations
parameters_config = self.ert_config.ensemble_config.parameter_configuration
design_matrix = self.ert_config.analysis_config.design_matrix
parameters_config = self._parameter_configuration
design_matrix = self._design_matrix
design_matrix_group = None
if design_matrix is not None:
try:
Expand All @@ -87,8 +101,8 @@ def run_experiment(
if design_matrix_group is not None
else parameters_config
),
observations=self.ert_config.observations,
responses=self.ert_config.ensemble_config.response_configuration,
observations=self._observations,
responses=self._response_configuration,
)
self.ensemble = self._storage.create_ensemble(
self.experiment,
Expand Down
21 changes: 17 additions & 4 deletions src/ert/run_models/ensemble_smoother.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from queue import SimpleQueue
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -42,10 +43,18 @@ def __init__(
super().__init__(
es_settings,
update_settings,
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
active_realizations=active_realizations,
start_iteration=0,
total_iterations=2,
Expand All @@ -57,6 +66,10 @@ def __init__(

self.support_restart = False

self._parameter_configuration = config.ensemble_config.parameter_configuration
self._observations = config.observations
self._response_configuration = config.ensemble_config.response_configuration

@tracer.start_as_current_span(f"{__name__}.run_experiment")
def run_experiment(
self, evaluator_server_config: EvaluatorServerConfig, restart: bool = False
Expand All @@ -66,9 +79,9 @@ def run_experiment(
self.run_workflows(HookRuntime.PRE_EXPERIMENT)
ensemble_format = self.target_ensemble_format
experiment = self._storage.create_experiment(
parameters=self.ert_config.ensemble_config.parameter_configuration,
observations=self.ert_config.observations,
responses=self.ert_config.ensemble_config.response_configuration,
parameters=self._parameter_configuration,
observations=self._observations,
responses=self._response_configuration,
name=self.experiment_name,
)

Expand Down
12 changes: 11 additions & 1 deletion src/ert/run_models/evaluate_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID

Expand Down Expand Up @@ -47,11 +48,20 @@ def __init__(
self.ensemble = storage.get_ensemble(UUID(ensemble_id))
except KeyError as err:
raise ValueError(f"No ensemble: {ensemble_id}") from err

super().__init__(
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
start_iteration=self.ensemble.iteration,
total_iterations=1,
active_realizations=active_realizations,
Expand Down
30 changes: 21 additions & 9 deletions src/ert/run_models/everest_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,26 @@ def __init__(

storage = open_storage(config.ens_path, mode="w")
status_queue: queue.SimpleQueue[StatusEvents] = queue.SimpleQueue()

super().__init__(
config,
storage,
config.runpath_file,
Path(config.user_config_file),
config.env_vars,
config.env_pr_fm_step,
config.model_config,
config.queue_config,
config.forward_model_steps,
status_queue,
config.substitutions,
config.ert_templates,
config.hooked_workflows,
active_realizations=[], # Set dynamically in run_forward_model()
)
self.support_restart = False
self._parameter_configuration = config.ensemble_config.parameter_configuration
self._parameter_configs = config.ensemble_config.parameter_configs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to keep both of these fields? One is a dict and the other is a list, so we could keep just the dict and call .values() on it (or something like that). Also, regarding the second invocation:

            ext_config = self.ert_config.ensemble_config.parameter_configs[control_name]

Here the experiment will always exist, so we could do self._experiment.parameter_configuration to get the parameter configs, and extract the ExtParam ones (should only be those anyway).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be part of a refactor of EnsembleConfig. Both parameter_configs and parameter_configuration seem to be used quite a bit throughout the code.

self._response_configuration = config.ensemble_config.response_configuration

@classmethod
def create(
Expand Down Expand Up @@ -187,8 +199,8 @@ def run_experiment(
self._eval_server_cfg = evaluator_server_config
self._experiment = self._storage.create_experiment(
name=f"EnOpt@{datetime.datetime.now().strftime('%Y-%m-%d@%H:%M:%S')}",
parameters=self.ert_config.ensemble_config.parameter_configuration,
responses=self.ert_config.ensemble_config.response_configuration,
parameters=self._parameter_configuration,
responses=self._response_configuration,
)

# Initialize the ropt optimizer:
Expand Down Expand Up @@ -494,7 +506,7 @@ def _check_suffix(
raise KeyError(err_msg)

for control_name, control in controls.items():
ext_config = self.ert_config.ensemble_config.parameter_configs[control_name]
ext_config = self._parameter_configs[control_name]
if isinstance(ext_config, ExtParamConfig):
if len(ext_config) != len(control.keys()):
raise KeyError(
Expand All @@ -515,7 +527,7 @@ def _get_run_args(
evaluator_context: EvaluatorContext,
batch_data: dict[int, Any],
) -> list[RunArg]:
substitutions = self.ert_config.substitutions
substitutions = self._substitutions
substitutions["<BATCH_NAME>"] = ensemble.name
self.active_realizations = [True] * len(batch_data)
for sim_id, control_idx in enumerate(batch_data.keys()):
Expand All @@ -525,11 +537,11 @@ def _get_run_args(
]
)
run_paths = Runpaths(
jobname_format=self.ert_config.model_config.jobname_format_string,
runpath_format=self.ert_config.model_config.runpath_format_string,
filename=str(self.ert_config.runpath_file),
jobname_format=self._model_config.jobname_format_string,
runpath_format=self._model_config.runpath_format_string,
filename=str(self._runpath_file),
substitutions=substitutions,
eclbase=self.ert_config.model_config.eclbase_format_string,
eclbase=self._model_config.eclbase_format_string,
)
return create_run_arguments(
run_paths,
Expand Down
Loading
Loading