Skip to content

Commit

Permalink
Use theory to calculate default corr threshold
Browse files Browse the repository at this point in the history
Theory suggests a correlation threshold of 3/sqrt(ensemble_size)
  • Loading branch information
dafeda committed Dec 13, 2022
1 parent b1bfd15 commit a946297
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/ert/_c_wrappers/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .analysis_module import AnalysisMode, AnalysisModule
from .analysis_module import AnalysisMode, AnalysisModule, correlation_threshold

__all__ = ["AnalysisModule", "AnalysisMode"]
__all__ = ["AnalysisModule", "AnalysisMode", "correlation_threshold"]
24 changes: 21 additions & 3 deletions src/ert/_c_wrappers/analysis/analysis_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import math
from enum import Enum
from typing import TYPE_CHECKING, Dict, List, Type, TypedDict, Union

Expand All @@ -22,7 +23,22 @@ class VariableInfo(TypedDict):
DEFAULT_ENKF_TRUNCATION = 0.98
DEFAULT_IES_INVERSION = 0
DEFAULT_LOCALIZATION = False
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD = 0.4
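# The default threshold is a function of ensemble size, which is not available
# here, so -1 is used as a sentinel meaning "compute 3 / sqrt(ensemble_size)
# once the ensemble size is known" (see correlation_threshold).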
DEFAULT_LOCALIZATION_CORRELATION_THRESHOLD = -1


def correlation_threshold(ensemble_size: int, user_defined_threshold: float) -> float:
    """Return the correlation threshold to use for localization.

    A ``user_defined_threshold`` of ``-1`` is the sentinel for "not set by the
    user"; in that case the default ``3 / sqrt(ensemble_size)`` is returned.
    Any other value is returned unchanged.

    The default is taken from luo2022, "Continuous Hyper-parameter
    OPtimization (CHOP) in an ensemble Kalman filter", Section 2.3 -
    Localization in the CHOP problem.

    Args:
        ensemble_size: Number of realizations in the ensemble. Only consulted
            when the default threshold is requested.
        user_defined_threshold: Threshold chosen by the user, or ``-1`` to
            request the default.

    Returns:
        The threshold to apply.

    Raises:
        ZeroDivisionError: If the default is requested and ``ensemble_size``
            is 0.
        ValueError: If the default is requested and ``ensemble_size`` is
            negative.
    """
    # Honour an explicit user choice first, so that it never depends on
    # ``ensemble_size`` being valid input for the default formula.
    if user_defined_threshold != -1:
        return user_defined_threshold

    # NOTE: the default exceeds 1.0 for ensembles with fewer than 9 members.
    return 3 / math.sqrt(ensemble_size)


class AnalysisMode(str, Enum):
Expand Down Expand Up @@ -222,8 +238,10 @@ def get_truncation(self) -> float:
def localization(self) -> bool:
    """Return the value stored under the ``LOCALIZATION`` variable."""
    enabled = self.get_variable_value("LOCALIZATION")
    return enabled

def localization_correlation_threshold(self) -> float:
    """Return the value stored under ``LOCALIZATION_CORRELATION_THRESHOLD``."""
    threshold = self.get_variable_value("LOCALIZATION_CORRELATION_THRESHOLD")
    return threshold
def localization_correlation_threshold(self, ensemble_size: int) -> float:
    """Return the localization correlation threshold for ``ensemble_size``.

    The stored ``LOCALIZATION_CORRELATION_THRESHOLD`` value is handed to
    ``correlation_threshold`` together with ``ensemble_size``, and the result
    of that call is returned.
    """
    configured = self.get_variable_value("LOCALIZATION_CORRELATION_THRESHOLD")
    return correlation_threshold(ensemble_size, configured)

def get_steplength(self, iteration_nr: int) -> float:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def analysis_ES(
responses_to_keep = _correlated_parameter_response_pairs(
A_chunk,
Y,
module.localization_correlation_threshold(),
module.localization_correlation_threshold(ensemble_size),
)
Y_loc = Y[responses_to_keep, :]
observation_errors_loc = observation_errors[responses_to_keep]
Expand Down
8 changes: 7 additions & 1 deletion src/ert/gui/ertwidgets/analysismodulevariablespanel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from functools import partial

from qtpy.QtWidgets import (
QCheckBox,
QDoubleSpinBox,
Expand All @@ -15,6 +14,7 @@
AnalysisModuleVariablesModel,
)
from ert.libres_facade import LibresFacade
from ert._c_wrappers.analysis import correlation_threshold


class AnalysisModuleVariablesPanel(QWidget):
Expand All @@ -41,10 +41,16 @@ def __init__(self, analysis_module_name: str, facade: LibresFacade):
variable_type = analysis_module_variables_model.getVariableType(
variable_name
)

variable_value = analysis_module_variables_model.getVariableValue(
self.facade, self._analysis_module_name, variable_name
)

if variable_name == "LOCALIZATION_CORRELATION_THRESHOLD":
variable_value = correlation_threshold(
self.facade.get_ensemble_size(), variable_value
)

label_name = analysis_module_variables_model.getVariableLabelName(
variable_name
)
Expand Down

0 comments on commit a946297

Please sign in to comment.