Skip to content

Commit

Permalink
Merge genkw and design matrix params
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Sep 25, 2024
1 parent f36ebbb commit f988b35
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def headerData(
orientation == Qt.Orientation.Horizontal
and role == Qt.ItemDataRole.DisplayRole
):
return self._df.columns[section]
return "\n".join(self._df.columns[section])
return QtCore.QVariant()


Expand Down
98 changes: 63 additions & 35 deletions src/ert/sensitivity_analysis/design_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -49,18 +50,40 @@ def read_design_matrix(
# ignoring errors here is deprecated in pandas, should find another solution
# design_matrix_sheet = design_matrix_sheet.apply(pd.to_numeric, errors="ignore")

existing_parameters = {
param.name
parameter_groups = defaultdict(list)
parameter_map = []
all_param_configs = [
param_group
for param_group in ert_config.ensemble_config.parameter_configuration
if isinstance(param_group, GenKwConfig)
for param in param_group.transform_function_definitions
}
intersect = existing_parameters.intersection(set(design_matrix_sheet.columns))
# This errors if parameters exists already, this behaviour should be discussed.
if intersect:
msg = "The following parameters were specified both"
f"as gen_kw and in the design matrix: {intersect}"
]
errors = {}
for param in design_matrix_sheet.columns:
par_gp = []
for param_group in all_param_configs:
if param in param_group:
par_gp.append(param_group.name)

if not par_gp:
parameter_name = "DESIGN_MATRIX"
parameter_groups[parameter_name].append(param)
parameter_map.append((parameter_name, param))
elif len(par_gp) == 1:
parameter_name = par_gp[0]
parameter_groups[parameter_name].append(param)
parameter_map.append((parameter_name, param))
else:
errors[param] = par_gp

if errors:
msg = ""
for key, value in errors.items():
msg += (
f"The following parameter '{key}' was found in multiple"
f" GenKw parameters groups: {value}."
)
raise ValueError(msg)
design_matrix_sheet.columns = pd.MultiIndex.from_tuples(parameter_map)
return design_matrix_sheet


Expand All @@ -71,30 +94,34 @@ def initialize_parameters(
exp_name: str,
ens_name: str,
) -> LocalEnsemble:
existing_parameters = ert_config.ensemble_config.parameter_configuration
parameters = design_matrix_sheet.columns
transform_function_definitions: list[TransformFunctionDefinition] = []
for param in parameters:
transform_function_definitions.append(
TransformFunctionDefinition(
name=param,
param_name="RAW",
values=[],
existing_parameters = ert_config.ensemble_config.parameter_configs.copy()
for parameter_group in design_matrix_sheet.columns.get_level_values(0).unique():
parameters = design_matrix_sheet[parameter_group].columns
transform_function_definitions: list[TransformFunctionDefinition] = []
for param in parameters:
transform_function_definitions.append(
TransformFunctionDefinition(
name=param,
param_name="RAW",
values=[],
)
)
)
existing_parameters.append(
GenKwConfig(
name=DESIGN_MATRIX_GROUP,
existing = existing_parameters.get(parameter_group)
existing_parameters[parameter_group] = GenKwConfig(
name=parameter_group,
forward_init=False,
template_file=None,
output_file=None,
template_file=existing.template_file
if isinstance(existing, GenKwConfig)
else None,
output_file=existing.output_file
if isinstance(existing, GenKwConfig)
else None,
transform_function_definitions=transform_function_definitions,
update=False,
)
)

experiment = storage.create_experiment(
parameters=existing_parameters,
parameters=list(existing_parameters.values()),
responses=ert_config.ensemble_config.response_configuration,
observations=ert_config.observations,
name=exp_name,
Expand All @@ -105,15 +132,16 @@ def initialize_parameters(
ensemble_size=len(design_matrix_sheet.index),
)
for i in design_matrix_sheet.index:
row: pd.Series = design_matrix_sheet.iloc[i]
ds = xr.Dataset(
{
"values": ("names", list(row.to_numpy())),
"transformed_values": ("names", list(row.to_numpy())),
"names": list(row.keys()),
}
)
ensemble.save_parameters(DESIGN_MATRIX_GROUP, i, ds)
for parameter_group in experiment.parameter_configuration:
row: pd.Series = design_matrix_sheet.iloc[i][parameter_group]
ds = xr.Dataset(
{
"values": ("names", list(row.to_numpy())),
"transformed_values": ("names", list(row.to_numpy())),
"names": list(row.keys()),
}
)
ensemble.save_parameters(parameter_group, i, ds)
return ensemble


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_design_matrix(copy_poly_case):
NUM_REALIZATIONS 100
MIN_REALIZATIONS 1
GEN_KW COEFFS coeff_priors
GEN_DATA POLY_RES RESULT_FILE:poly.out
INSTALL_JOB poly_eval POLY_EVAL
Expand Down

0 comments on commit f988b35

Please sign in to comment.