def set_var(self, var_name: str, value: Union[float, int, bool, str]):
    """Set an analysis-module variable.

    Numeric variables are converted to their declared type and clamped to
    the ``[min, max]`` range declared in ``self._variables`` (with a warning
    when clamping occurs).  Boolean variables are parsed explicitly, because
    ``bool("False")`` is ``True`` in Python — the patch under review would
    silently turn ``LOCALIZATION False`` (a config string) into ``True``.

    Args:
        var_name: key into ``self._variables`` (or a special key handled by
            ``handle_special_key_set``).
        value: new value, possibly as a string coming from an ert config.

    Raises:
        ValueError: if ``value`` cannot be converted to the variable's type.
        KeyError: if ``var_name`` is not known to this module.
    """
    # NOTE(review): the guard in front of handle_special_key_set lies outside
    # the visible hunk; reconstructed here — confirm against the original file.
    if var_name in self._special_variables:
        self.handle_special_key_set(var_name, value)
    elif var_name in self._variables:
        var = self._variables[var_name]
        if var["type"] is bool:
            if isinstance(value, str):
                # Explicit string parsing: bool(<non-empty str>) is always
                # True, so "False"/"0"/... must be handled by hand.
                lowered = value.strip().lower()
                if lowered in ("true", "1", "yes", "on"):
                    var["value"] = True
                elif lowered in ("false", "0", "no", "off"):
                    var["value"] = False
                else:
                    raise ValueError(
                        f"Variable {var_name} expected type {var['type']}"
                        f" received value `{value}` of type `{type(value)}`"
                    )
            else:
                var["value"] = bool(value)
        else:
            # Keep the try-body minimal: only the conversion can raise here.
            try:
                new_value = var["type"](value)
            except ValueError:
                raise ValueError(
                    f"Variable {var_name} expected type {var['type']}"
                    f" received value `{value}` of type `{type(value)}`"
                )
            if new_value > var["max"]:
                var["value"] = var["max"]
                logger.warning(
                    f"New value {new_value} for key"
                    f" {var_name} is out of [{var['min']}, {var['max']}] "
                    f"using max value {var['max']}"
                )
            elif new_value < var["min"]:
                var["value"] = var["min"]
                logger.warning(
                    f"New value {new_value} for key"
                    f" {var_name} is out of [{var['min']}, {var['max']}] "
                    f"using min value {var['min']}"
                )
            else:
                var["value"] = new_value
    else:
        raise KeyError(f"Variable {var_name} not found in module")
def correlated_parameter_response_pairs(
    A: npt.NDArray[np.float_], Y: npt.NDArray[np.float_], correlation_threshold: float
) -> npt.NDArray[np.int_]:
    """Return indices of responses strongly correlated with the parameters.

    Computes the (parameters x responses) sample correlation matrix between
    the rows of ``A`` (parameter ensemble, shape ``(n_params, N)``) and the
    rows of ``Y`` (response ensemble, shape ``(n_responses, N)``), and
    returns the sorted, de-duplicated column indices of ``Y`` where
    ``|corr| > correlation_threshold`` for *any* parameter row.

    Fixes vs. the patch under review:
    - ``np.unique`` on the result: with a multi-row ``A`` (the analysis_IES
      call site) ``np.where`` yields one entry per (param, response) pair, so
      the same observation was selected repeatedly and double-counted in the
      localized update.
    - divides by the std vectors directly instead of inverting two diagonal
      matrices with ``np.linalg.inv``.

    NOTE(review): a zero-variance row in ``A`` or ``Y`` still divides by
    zero, as in the original — assumed non-degenerate ensembles; confirm.
    """
    N = A.shape[1]
    A_prime = A - A.mean(axis=1, keepdims=True)
    Y_prime = Y - Y.mean(axis=1, keepdims=True)

    # Sample standard deviations (ddof = 1).
    sigma_A = np.sqrt(np.sum(A_prime**2, axis=1) / (N - 1))
    sigma_Y = np.sqrt(np.sum(Y_prime**2, axis=1) / (N - 1))

    # State-measurement covariance, then correlation, matrix.
    C_AY = A_prime @ Y_prime.T / (N - 1)
    c_AY = C_AY / np.outer(sigma_A, sigma_Y)

    return np.unique(np.where(np.abs(c_AY) > correlation_threshold)[1])
def _localized_ensemble_smoother_update(A, Y, observation_errors, observation_values, module, rng):
    """Adaptive (per-parameter) localization for the ES update.

    For every parameter row of ``A`` (shape ``(n_params, N)``), selects only
    the responses in ``Y`` whose correlation with that parameter exceeds the
    module's threshold, runs the smoother on that subset, and reassembles the
    updated ensemble.  Returns the updated ``A`` (same shape).

    Fixes vs. the patch under review:
    - the smoother was called with the full ``A`` instead of the single-row
      ``A_chunk`` just computed, so the per-parameter localization had no
      effect on what was updated;
    - ``np.vstack(A_loc)`` stacked only the *last* chunk instead of the
      accumulated ``A_ES_loc`` list;
    - the result was saved inside the loop, so earlier rows were overwritten
      by the final iteration — the caller should save the returned matrix
      once, after this function.
    """
    updated_rows = []
    ensemble_size = A.shape[1]
    threshold = module.localization_correlation_threshold()
    truncation = module.get_truncation()
    for i in range(A.shape[0]):
        A_chunk = A[i, :].reshape(1, ensemble_size)
        corr_idx_Y = correlated_parameter_response_pairs(A_chunk, Y, threshold)
        if len(corr_idx_Y) == 0:
            # No response passes the threshold: keep this parameter at its
            # prior instead of calling the smoother with zero observations.
            updated_rows.append(A_chunk)
            continue
        Y_loc = Y[corr_idx_Y, :]
        observation_errors_loc = observation_errors[corr_idx_Y]
        observation_values_loc = observation_values[corr_idx_Y]
        noise = rng.standard_normal(size=(len(observation_values_loc), ensemble_size))
        updated_rows.append(
            ies.ensemble_smoother_update_step(
                Y_loc,
                A_chunk,  # was: A — the whole ensemble, re-updated every pass
                observation_errors_loc,
                observation_values_loc,
                noise,
                truncation,
                ies.InversionType(module.inversion),
            )
        )
    return np.vstack(updated_rows)
inversion=ies.InversionType(module.inversion), - truncation=module.get_truncation(), - ) - _save_to_temporary_storage(temp_storage, update_step.parameters, A) + if module.localization(): + corr_idx_Y = correlated_parameter_response_pairs( + A, Y, module.localization_correlation_threshold() + ) + Y_loc = Y[corr_idx_Y, :] + observation_errors_loc = observation_errors[corr_idx_Y] + observation_values_loc = observation_values[corr_idx_Y] + noise = rng.standard_normal(size=(len(observation_values_loc), Y.shape[1])) + + A_loc = iterative_ensemble_smoother.update_step( + Y_loc, + A, + observation_errors_loc, + observation_values_loc, + noise, + ensemble_mask=np.array(ens_mask), + observation_mask=observation_mask, + inversion=ies.InversionType(module.inversion), + truncation=module.get_truncation(), + ) + _save_to_temporary_storage(temp_storage, update_step.parameters, A_loc) + else: + noise = rng.standard_normal(size=(len(observation_values), Y.shape[1])) + A = iterative_ensemble_smoother.update_step( + Y, + A, + observation_errors, + observation_values, + noise, + ensemble_mask=np.array(ens_mask), + observation_mask=observation_mask, + inversion=ies.InversionType(module.inversion), + truncation=module.get_truncation(), + ) + _save_to_temporary_storage(temp_storage, update_step.parameters, A) _save_temporary_storage_to_disk( target_fs, ensemble_config, temp_storage, iens_active_index diff --git a/test-data/poly_example/poly.ert b/test-data/poly_example/poly.ert index 6551792857c..850826d7854 100644 --- a/test-data/poly_example/poly.ert +++ b/test-data/poly_example/poly.ert @@ -1,5 +1,7 @@ JOBNAME poly_%d +ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True + QUEUE_SYSTEM LOCAL QUEUE_OPTION LOCAL MAX_RUNNING 50 diff --git a/test-data/snake_oil/snake_oil.ert b/test-data/snake_oil/snake_oil.ert index f72e60ec4cf..5f24c6cb141 100644 --- a/test-data/snake_oil/snake_oil.ert +++ b/test-data/snake_oil/snake_oil.ert @@ -5,6 +5,9 @@ NUM_REALIZATIONS 25 ANALYSIS_SET_VAR IES_ENKF 
IES_INVERSION 1 ANALYSIS_SET_VAR STD_ENKF IES_INVERSION 1 +ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True +ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.0 + DEFINE storage/ RANDOM_SEED 3593114179000630026631423308983283277868 diff --git a/tests/unit_tests/c_wrappers/res/analysis/test_analysis_module.py b/tests/unit_tests/c_wrappers/res/analysis/test_analysis_module.py index c371ff14412..778829b2f8b 100644 --- a/tests/unit_tests/c_wrappers/res/analysis/test_analysis_module.py +++ b/tests/unit_tests/c_wrappers/res/analysis/test_analysis_module.py @@ -7,6 +7,8 @@ DEFAULT_IES_INVERSION, DEFAULT_IES_MAX_STEPLENGTH, DEFAULT_IES_MIN_STEPLENGTH, + DEFAULT_LOCALIZATION, + DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD, get_mode_variables, ) @@ -21,12 +23,16 @@ def test_analysis_module_default_values(): "IES_DEC_STEPLENGTH": DEFAULT_IES_DEC_STEPLENGTH, "IES_INVERSION": DEFAULT_IES_INVERSION, "ENKF_TRUNCATION": DEFAULT_ENKF_TRUNCATION, + "LOCALIZATION": DEFAULT_LOCALIZATION, + "LOCALIZATION_CORRELATION_THRESHOLD": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD, # noqa } es_am = AnalysisModule.ens_smoother_module() assert es_am.variable_value_dict() == { "IES_INVERSION": DEFAULT_IES_INVERSION, "ENKF_TRUNCATION": DEFAULT_ENKF_TRUNCATION, + "LOCALIZATION": DEFAULT_LOCALIZATION, + "LOCALIZATION_CORRELATION_THRESHOLD": DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD, # noqa }