Skip to content

Commit

Permalink
Drop row with all False to avoid grp.sum() > 0
Browse files Browse the repository at this point in the history
  • Loading branch information
Feda Curic committed Apr 17, 2023
1 parent 6d06861 commit 5d1837e
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,33 +444,38 @@ def analysis_ES(
)
c_bool = c_AY > correlation_threshold
# Some parameters might be significantly correlated to the exact same
# responses, making up what we cann a parameter group.
# responses, making up what we call a `parameter group``.
# We want to call the update only once per such parameter group
# to speed up computation.
param_groups = np.unique(c_bool, axis=0)

# Drop the parameter group that does not correlate to any responses.
row_with_all_false = np.all(param_groups == False, axis=1)
param_groups = param_groups[~row_with_all_false]

for grp in param_groups:
if grp.sum() > 0:
param_idx = np.where((c_bool == grp).all(axis=1))[0]
# A_chunk = A[param_batch_idx, :][param_idx, :]
X_chunk = temp_storage[parameter.name][param_batch_idx, :][
param_idx, :
]
Y_chunk = Y[grp, :]
observation_errors_loc = observation_errors[grp]
observation_values_loc = observation_values[grp]
smoother.fit(
Y_chunk,
observation_errors_loc,
observation_values_loc,
noise=noise[grp],
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
temp_storage[parameter.name][
param_batch_idx[param_idx], :
] = smoother.update(X_chunk)
# Find the rows matching the parameter group
matching_rows = np.all(c_bool == grp, axis=1)
# Get the indices of the matching rows
row_indices = np.where(matching_rows)[0]
X_chunk = temp_storage[parameter.name][param_batch_idx, :][
row_indices, :
]
Y_chunk = Y[grp, :]
observation_errors_loc = observation_errors[grp]
observation_values_loc = observation_values[grp]
smoother.fit(
Y_chunk,
observation_errors_loc,
observation_values_loc,
noise=noise[grp],
truncation=truncation,
inversion=ies.InversionType(module.inversion),
param_ensemble=param_ensemble,
)
temp_storage[parameter.name][
param_batch_idx[row_indices], :
] = smoother.update(X_chunk)
else:
for parameter in update_step.parameters:
smoother.fit(
Expand Down

0 comments on commit 5d1837e

Please sign in to comment.