Skip to content

Commit

Permalink
Property test parameter-reponse corr function
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Dec 1, 2022
1 parent e7411d9 commit 12ac2fd
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def _create_temporary_parameter_storage(
return temporary_storage


def correlated_parameter_response_pairs(
def _correlated_parameter_response_pairs(
A: npt.NDArray[np.float_], Y: npt.NDArray[np.float_], correlation_threshold: float
) -> npt.NDArray[np.int_]:
N = A.shape[1]
Expand Down Expand Up @@ -210,7 +210,7 @@ def analysis_ES(
p = A.shape[0]
for i in range(A.shape[0]):
A_chunk = A[i, :].reshape(-1, N)
corr_idx_Y = correlated_parameter_response_pairs(
corr_idx_Y = _correlated_parameter_response_pairs(
A_chunk,
Y,
module.localization_correlation_threshold(),
Expand Down
45 changes: 44 additions & 1 deletion tests/unit_tests/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np

rng = np.random.default_rng()
from scipy.linalg import toeplitz

import pandas as pd
import pytest
from iterative_ensemble_smoother import IterativeEnsembleSmoother
Expand All @@ -13,7 +15,10 @@
from ert.__main__ import ert_parser
from ert._c_wrappers.enkf import EnKFMain, EnkfNode, NodeId, ResConfig, RunContext
from ert.analysis import ErtAnalysisError, ESUpdate
from ert.analysis._es_update import _create_temporary_parameter_storage
from ert.analysis._es_update import (
_create_temporary_parameter_storage,
_correlated_parameter_response_pairs,
)
from ert.cli import ENSEMBLE_EXPERIMENT_MODE, ENSEMBLE_SMOOTHER_MODE
from ert.cli.main import run_cli

Expand Down Expand Up @@ -492,3 +497,41 @@ def test_update_multiple_param(copy_case):
# https://en.wikipedia.org/wiki/Variance#For_vector-valued_random_variables
for prior_name, prior_data in prior.items():
assert np.trace(np.cov(posterior[prior_name])) < np.trace(np.cov(prior_data))


def test_correlated_parameter_response_pair():
p = 4
m = 2
N = 1000
rho = 0.9
# Correlation matrix or AR(1) model
R = toeplitz([rho**i for i in range(p + m)])

# Get correlated samples from multivariate normal distribution
Z = rng.standard_normal(size=(p + m, N))
X = np.linalg.cholesky(R) @ Z

assert (
len(
_correlated_parameter_response_pairs(
X[0, :].reshape(-1, N), X[1, :].reshape(-1, N), rho - 0.05
)
)
== 1
)
assert (
len(
_correlated_parameter_response_pairs(
X[0, :].reshape(-1, N), X[1:4, :], rho**2 - 0.05
)
)
== 2
)
assert (
len(
_correlated_parameter_response_pairs(
X[0, :].reshape(-1, N), X[1:6, :], rho**3 - 0.05
)
)
== 3
)

0 comments on commit 12ac2fd

Please sign in to comment.