Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run ensemble experiment with design matrix #8941

Merged
merged 1 commit into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 69 additions & 13 deletions src/ert/config/design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
from .parsing import ConfigValidationError, ErrorInfo

if TYPE_CHECKING:
from ert.config import (
ParameterConfig,
)
from ert.config import ParameterConfig

DESIGN_MATRIX_GROUP = "DESIGN_MATRIX"

Expand All @@ -28,10 +26,17 @@ class DesignMatrix:
default_sheet: str

def __post_init__(self) -> None:
self.num_realizations: int | None = None
self.active_realizations: list[bool] | None = None
self.design_matrix_df: pd.DataFrame | None = None
self.parameter_configuration: dict[str, ParameterConfig] | None = None
try:
(
self.active_realizations,
self.design_matrix_df,
self.parameter_configuration,
) = self.read_design_matrix()
except (ValueError, AttributeError) as exc:
raise ConfigValidationError.with_context(
f"Error reading design matrix {self.xls_filename}: {exc}",
str(self.xls_filename),
) from exc

@classmethod
def from_config_list(cls, config_list: list[str]) -> DesignMatrix:
Expand Down Expand Up @@ -73,9 +78,60 @@ def from_config_list(cls, config_list: list[str]) -> DesignMatrix:
default_sheet=default_sheet,
)

def merge_with_existing_parameters(
self, existing_parameters: list[ParameterConfig]
) -> tuple[list[ParameterConfig], ParameterConfig | None]:
"""
This method merges the design matrix parameters with the existing parameters and
returns the new list of existing parameters, wherein we drop GEN_KW group having a full overlap with the design matrix group.
GEN_KW group that was dropped will acquire a new name from the design matrix group.
Additionally, the ParameterConfig which is the design matrix group is returned separately.
Args:
existing_parameters (List[ParameterConfig]): List of existing parameters
Raises:
ConfigValidationError: If there is a partial overlap between the design matrix group and any existing GEN_KW group
Returns:
tuple[List[ParameterConfig], ParameterConfig]: List of existing parameters and the dedicated design matrix group
"""

new_param_config: list[ParameterConfig] = []

design_parameter_group = self.parameter_configuration[DESIGN_MATRIX_GROUP]
design_keys = []
if isinstance(design_parameter_group, GenKwConfig):
larsevj marked this conversation as resolved.
Show resolved Hide resolved
design_keys = [e.name for e in design_parameter_group.transform_functions]

design_group_added = False
for parameter_group in existing_parameters:
if not isinstance(parameter_group, GenKwConfig):
new_param_config += [parameter_group]
continue
existing_keys = [e.name for e in parameter_group.transform_functions]
if set(existing_keys) == set(design_keys):
larsevj marked this conversation as resolved.
Show resolved Hide resolved
if design_group_added:
raise ConfigValidationError(
"Multiple overlapping groups with design matrix found in existing parameters!\n"
f"{design_parameter_group.name} and {parameter_group.name}"
)

design_parameter_group.name = parameter_group.name
design_group_added = True
elif set(design_keys) & set(existing_keys):
raise ConfigValidationError(
"Overlapping parameter names found in design matrix!\n"
f"{DESIGN_MATRIX_GROUP}:{design_keys}\n{parameter_group.name}:{existing_keys}"
"\nThey need to much exactly or not at all."
)
else:
new_param_config += [parameter_group]
return new_param_config, design_parameter_group

def read_design_matrix(
self,
) -> None:
) -> tuple[list[bool], pd.DataFrame, dict[str, ParameterConfig]]:
# Read the parameter names (first row) as strings to prevent pandas from modifying them.
# This ensures that duplicate or empty column names are preserved exactly as they appear in the Excel sheet.
# By doing this, we can properly validate variable names, including detecting duplicates or missing names.
Expand Down Expand Up @@ -139,11 +195,11 @@ def read_design_matrix(
[[DESIGN_MATRIX_GROUP], design_matrix_df.columns]
)
reals = design_matrix_df.index.tolist()
self.num_realizations = len(reals)
self.active_realizations = [x in reals for x in range(max(reals) + 1)]

self.design_matrix_df = design_matrix_df
self.parameter_configuration = parameter_configuration
return (
[x in reals for x in range(max(reals) + 1)],
design_matrix_df,
parameter_configuration,
)

@staticmethod
def _read_excel(
Expand Down
39 changes: 31 additions & 8 deletions src/ert/enkf_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,17 @@
from typing import TYPE_CHECKING, Any

import orjson
import pandas as pd
import xarray as xr
from numpy.random import SeedSequence

from ert.config.ert_config import forward_model_data_to_json
from ert.config.forward_model_step import ForwardModelStep
from ert.config.model_config import ModelConfig
from ert.substitutions import Substitutions, substitute_runpath_name

from .config import (
ExtParamConfig,
Field,
GenKwConfig,
ParameterConfig,
SurfaceConfig,
)
from .config import ExtParamConfig, Field, GenKwConfig, ParameterConfig, SurfaceConfig
from .config.design_matrix import DESIGN_MATRIX_GROUP
from .run_arg import RunArg
from .runpaths import Runpaths

Expand Down Expand Up @@ -53,7 +50,10 @@ def _value_export_txt(
with path.open("w") as f:
for key, param_map in values.items():
for param, value in param_map.items():
print(f"{key}:{param} {value:g}", file=f)
if isinstance(value, (int | float)):
print(f"{key}:{param} {value:g}", file=f)
else:
print(f"{key}:{param} {value}", file=f)


def _value_export_json(
Expand Down Expand Up @@ -156,6 +156,29 @@ def _seed_sequence(seed: int | None) -> int:
return int_seed


def save_design_matrix_to_ensemble(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be a method in the ensemble class? my_ensemble.save_design_matrix(...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already ensemble.save_parameters, which represents the interface for storing parameters. Not sure if it would bring any value having it there as it is a special "class" of parameters.

design_matrix_df: pd.DataFrame,
ensemble: Ensemble,
active_realizations: Iterable[int],
design_group_name: str = DESIGN_MATRIX_GROUP,
) -> None:
assert not design_matrix_df.empty
for realization_nr in active_realizations:
row = design_matrix_df.loc[realization_nr][DESIGN_MATRIX_GROUP]
ds = xr.Dataset(
{
"values": ("names", list(row.values)),
"transformed_values": ("names", list(row.values)),
"names": list(row.keys()),
}
)
ensemble.save_parameters(
design_group_name,
realization_nr,
ds,
)


def sample_prior(
ensemble: Ensemble,
active_realizations: Iterable[int],
Expand Down
30 changes: 12 additions & 18 deletions src/ert/gui/simulation/ensemble_experiment_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ert.gui.tools.design_matrix.design_matrix_panel import DesignMatrixPanel
from ert.mode_definitions import ENSEMBLE_EXPERIMENT_MODE
from ert.run_models import EnsembleExperiment
from ert.validation import RangeStringArgument
from ert.validation import ActiveRange, RangeStringArgument
from ert.validation.proper_name_argument import ExperimentValidation, ProperNameArgument

from .experiment_config_panel import ExperimentConfigPanel
Expand Down Expand Up @@ -85,6 +85,9 @@ def __init__(

design_matrix = analysis_config.design_matrix
if design_matrix is not None:
self._active_realizations_field.setText(
ActiveRange(design_matrix.active_realizations).rangestring
)
show_dm_param_button = QPushButton("Show parameters")
show_dm_param_button.setObjectName("show-dm-parameters")
show_dm_param_button.setMinimumWidth(50)
Expand Down Expand Up @@ -113,23 +116,14 @@ def __init__(
self.notifier.ertChanged.connect(self._update_experiment_name_placeholder)

def on_show_dm_params_clicked(self, design_matrix: DesignMatrix) -> None:
assert design_matrix is not None

if design_matrix.design_matrix_df is None:
design_matrix.read_design_matrix()

if (
design_matrix.design_matrix_df is not None
and not design_matrix.design_matrix_df.empty
):
viewer = DesignMatrixPanel(
design_matrix.design_matrix_df,
design_matrix.xls_filename.name,
)
viewer.setMinimumHeight(500)
viewer.setMinimumWidth(1000)
viewer.adjustSize()
viewer.exec_()
viewer = DesignMatrixPanel(
design_matrix.design_matrix_df,
design_matrix.xls_filename.name,
)
viewer.setMinimumHeight(500)
viewer.setMinimumWidth(1000)
viewer.adjustSize()
viewer.exec_()

@Slot(ExperimentConfigPanel)
def experimentTypeChanged(self, w: ExperimentConfigPanel) -> None:
Expand Down
33 changes: 30 additions & 3 deletions src/ert/run_models/ensemble_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@

import numpy as np

from ert.enkf_main import sample_prior
from ert.config import ConfigValidationError
from ert.enkf_main import sample_prior, save_design_matrix_to_ensemble
from ert.ensemble_evaluator import EvaluatorServerConfig
from ert.storage import Ensemble, Experiment, Storage
from ert.trace import tracer

from ..run_arg import create_run_arguments
from .base_run_model import BaseRunModel, StatusEvents
from .base_run_model import BaseRunModel, ErtRunError, StatusEvents

if TYPE_CHECKING:
from ert.config import ErtConfig, QueueConfig
Expand Down Expand Up @@ -64,10 +65,27 @@ def run_experiment(
) -> None:
self.log_at_startup()
self.restart = restart
# If design matrix is present, we try to merge design matrix parameters
# to the experiment parameters and set new active realizations
parameters_config = self.ert_config.ensemble_config.parameter_configuration
design_matrix = self.ert_config.analysis_config.design_matrix
design_matrix_group = None
if design_matrix is not None:
try:
parameters_config, design_matrix_group = (
design_matrix.merge_with_existing_parameters(parameters_config)
)
except ConfigValidationError as exc:
raise ErtRunError(str(exc)) from exc

if not restart:
self.experiment = self._storage.create_experiment(
name=self.experiment_name,
parameters=self.ert_config.ensemble_config.parameter_configuration,
parameters=(
[*parameters_config, design_matrix_group]
larsevj marked this conversation as resolved.
Show resolved Hide resolved
if design_matrix_group is not None
else parameters_config
),
observations=self.ert_config.observations,
responses=self.ert_config.ensemble_config.response_configuration,
)
Expand All @@ -90,12 +108,21 @@ def run_experiment(
np.array(self.active_realizations, dtype=bool),
ensemble=self.ensemble,
)

sample_prior(
self.ensemble,
np.where(self.active_realizations)[0],
random_seed=self.random_seed,
)

if design_matrix_group is not None and design_matrix is not None:
save_design_matrix_to_ensemble(
design_matrix.design_matrix_df,
self.ensemble,
np.where(self.active_realizations)[0],
design_matrix_group.name,
)

self._evaluate_and_postprocess(
run_args,
self.ensemble,
Expand Down
17 changes: 12 additions & 5 deletions src/ert/run_models/model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,20 @@ def _setup_ensemble_experiment(
args: Namespace,
status_queue: SimpleQueue[StatusEvents],
) -> EnsembleExperiment:
active_realizations = _realizations(args, config.model_config.num_realizations)
active_realizations = _realizations(
args, config.model_config.num_realizations
).tolist()
if (
config.analysis_config.design_matrix is not None
and config.analysis_config.design_matrix.active_realizations is not None
):
active_realizations = config.analysis_config.design_matrix.active_realizations
experiment_name = args.experiment_name
assert experiment_name is not None

return EnsembleExperiment(
random_seed=config.random_seed,
active_realizations=active_realizations.tolist(),
active_realizations=active_realizations,
ensemble_name=args.current_ensemble,
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
experiment_name=experiment_name,
Expand Down Expand Up @@ -271,9 +278,9 @@ def _setup_iterative_ensemble_smoother(
random_seed=config.random_seed,
active_realizations=active_realizations.tolist(),
target_ensemble=_iterative_ensemble_format(args),
number_of_iterations=int(args.num_iterations)
if args.num_iterations is not None
else 4,
number_of_iterations=(
int(args.num_iterations) if args.num_iterations is not None else 4
),
minimum_required_realizations=config.analysis_config.minimum_required_realizations,
num_retries_per_iter=4,
experiment_name=experiment_name,
Expand Down
Loading
Loading