Skip to content

Commit

Permalink
Remove multi-realization save logic
Browse files Browse the repository at this point in the history
(testing)
  • Loading branch information
yngve-sk committed Nov 26, 2024
1 parent c0cb788 commit bb03186
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 29 deletions.
5 changes: 3 additions & 2 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,9 @@ def _copy_unupdated_parameters(

# Copy the non-updated parameter groups from source to target for each active realization
for parameter_group in not_updated_parameter_groups:
ds = source_ensemble.load_parameters(parameter_group, iens_active_index)
target_ensemble.save_parameters(parameter_group, iens_active_index, ds)
for realization in iens_active_index:
ds = source_ensemble.load_parameters(parameter_group, realization)
target_ensemble.save_parameters(parameter_group, realization, ds)


def analysis_ES(
Expand Down
35 changes: 9 additions & 26 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,10 @@ def _load_single_dataset(
def _load_dataset(
self,
group: str,
realizations: Union[int, npt.NDArray[np.int_], None],
realizations: Union[int, np.int64, npt.NDArray[np.int_], None],
) -> xr.Dataset:
if isinstance(realizations, int):
return self._load_single_dataset(group, realizations).isel(
if isinstance(realizations, (int, np.int64)):
return self._load_single_dataset(group, int(realizations)).isel(
realizations=0, drop=True
)

Expand Down Expand Up @@ -792,7 +792,7 @@ def load_all_gen_kw_data(
def save_parameters(
self,
group: str,
realization: Union[int, npt.NDArray[np.int_]],
realization: int,
dataset: xr.Dataset,
) -> None:
"""
Expand Down Expand Up @@ -820,30 +820,13 @@ def save_parameters(
if group not in self.experiment.parameter_configuration:
raise ValueError(f"{group} is not registered to the experiment.")

# Convert to numpy array if it's an integer
realizations: npt.NDArray[np.int_] = (
np.array([realization])
if isinstance(realization, (int, np.integer))
else np.asarray(realization)
)

if realizations.size > 1 and "realizations" not in dataset.dims:
raise ValueError(
"Dataset must have 'realizations' dimension when saving multiple realizations"
)

path = self._realization_dir(realization) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)
if "realizations" in dataset.dims:
dataset = dataset.sel(realizations=realizations)
for real, data_to_save in dataset.groupby("realizations"):
path = self._realization_dir(real) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)
self._storage._to_netcdf_transaction(path, data_to_save)
data_to_save = dataset.sel(realizations=[realization])
else:
for real in realizations:
path = self._realization_dir(real) / f"{_escape_filename(group)}.nc"
path.parent.mkdir(exist_ok=True)
data_to_save = dataset.expand_dims(realizations=[real])
self._storage._to_netcdf_transaction(path, data_to_save)
data_to_save = dataset.expand_dims(realizations=[realization])
self._storage._to_netcdf_transaction(path, data_to_save)

@require_write
def save_response(
Expand Down
3 changes: 2 additions & 1 deletion tests/ert/ui_tests/cli/test_field_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ def test_field_param_update_using_heat_equation_zero_var_params_and_adaptive_loc
name="prior-zero-var",
)
cond["values"][:, :, :5, 0] = 1.0
new_prior.save_parameters("COND", range(prior.ensemble_size), cond)
for real in range(prior.ensemble_size):
new_prior.save_parameters("COND", real, cond)

# Copy responses from existing prior to new prior.
# Note that we ideally should generate new responses by running the
Expand Down

0 comments on commit bb03186

Please sign in to comment.