Skip to content

Commit

Permalink
Return cross correlations for multiple GEN_KWs
Browse files Browse the repository at this point in the history
  • Loading branch information
dafeda committed Dec 11, 2024
1 parent c9d43bc commit c44eac7
Show file tree
Hide file tree
Showing 11 changed files with 506 additions and 487 deletions.
15 changes: 15 additions & 0 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
from collections.abc import Callable, Iterable, Sequence
from fnmatch import fnmatch
from itertools import groupby
from typing import (
TYPE_CHECKING,
Generic,
Expand Down Expand Up @@ -168,6 +169,7 @@ def _load_observations_and_responses(
npt.NDArray[np.float64],
tuple[
npt.NDArray[np.float64],
list[str],
npt.NDArray[np.float64],
list[ObservationAndResponseSnapshot],
],
Expand Down Expand Up @@ -315,6 +317,7 @@ def _load_observations_and_responses(

return S[obs_mask], (
observations[obs_mask],
obs_keys[obs_mask],
scaled_errors[obs_mask],
update_snapshot,
)
Expand Down Expand Up @@ -458,6 +461,7 @@ def adaptive_localization_progress_callback(
S,
(
observation_values,
observation_keys,
observation_errors,
update_snapshot,
),
Expand All @@ -474,6 +478,14 @@ def adaptive_localization_progress_callback(
num_obs = len(observation_values)

smoother_snapshot.update_step_snapshots = update_snapshot
# Used as labels for observations in cross-correlation matrix.
# Say we have two observation groups "FOPR" and "WOPR" where "FOPR" has
# 2 responses and "WOPR" has 3.
# In this case we create a list [FOPR_0, FOPR_1, WOPR_0, WOPR_1, WOPR_2]
# as labels for observations.
unique_obs_names = [
f"{k}_{i}" for k, g in groupby(observation_keys) for i, _ in enumerate(list(g))
]

if num_obs == 0:
msg = "No active observations for update step"
Expand Down Expand Up @@ -577,6 +589,8 @@ def correlation_callback(
cross_correlations_,
param_group,
parameter_names[: cross_correlations_.shape[0]],
unique_obs_names,
list(observation_keys),
)
logger.info(
f"Adaptive Localization of {param_group} completed in {(time.time() - start) / 60} minutes"
Expand Down Expand Up @@ -639,6 +653,7 @@ def analysis_IES(
S,
(
observation_values,
_,
observation_errors,
update_snapshot,
),
Expand Down
39 changes: 25 additions & 14 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np
import pandas as pd
import polars as pl
import xarray as xr
from pydantic import BaseModel
from typing_extensions import deprecated
Expand Down Expand Up @@ -560,16 +561,15 @@ def load_parameters(

return self._load_dataset(group, realizations)

def load_cross_correlations(self) -> xr.Dataset:
input_path = self.mount_point / "corr_XY.nc"

def load_cross_correlations(self) -> pl.DataFrame:
input_path = self.mount_point / "corr_XY.parquet"
if not input_path.exists():
raise FileNotFoundError(
f"No cross-correlation data available at '{input_path}'. Make sure to run the update with "
"Adaptive Localization enabled."
)
logger.info("Loading cross correlations")
return xr.open_dataset(input_path, engine="scipy")
return pl.read_parquet(input_path)

@require_write
def save_observation_scaling_factors(self, dataset: polars.DataFrame) -> None:
Expand All @@ -592,17 +592,28 @@ def save_cross_correlations(
cross_correlations: npt.NDArray[np.float64],
param_group: str,
parameter_names: list[str],
unique_obs_names: list[str],
observation_keys: list[str],
) -> None:
data_vars = {
param_group: xr.DataArray(
data=cross_correlations,
dims=["parameter", "response"],
coords={"parameter": parameter_names},
)
}
dataset = xr.Dataset(data_vars)
file_path = os.path.join(self.mount_point, "corr_XY.nc")
self._storage._to_netcdf_transaction(file_path, dataset)
n_responses = cross_correlations.shape[1]
new_df = pl.DataFrame(
{
"param_group": [param_group]
* (len(parameter_names) * len(unique_obs_names)),
"param_name": np.repeat(parameter_names, n_responses),
"obs_group": observation_keys * len(parameter_names),
"obs_name": unique_obs_names * len(parameter_names),
"value": cross_correlations.ravel(),
}
)

file_path = os.path.join(self.mount_point, "corr_XY.parquet")
if os.path.exists(file_path):
existing_df = pl.read_parquet(file_path)
df = pl.concat([existing_df, new_df])
else:
df = new_df
self._storage._to_parquet_transaction(file_path, df)

def load_responses(
self, key: str, realizations: tuple[int, ...]
Expand Down
264 changes: 264 additions & 0 deletions test-data/ert/heat_equation/Plot_correlations.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions test-data/ert/heat_equation/config.ert
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ QUEUE_OPTION LOCAL MAX_RUNNING 100

RANDOM_SEED 11223344

ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.1

NUM_REALIZATIONS 100
GRID CASE.EGRID

OBS_CONFIG observations

FIELD COND PARAMETER cond.bgrdecl INIT_FILES:cond.bgrdecl FORWARD_INIT:True

GEN_KW INIT_TEMP_SCALE init_temp_prior.txt
GEN_KW CORR_LENGTH corr_length_prior.txt

GEN_DATA MY_RESPONSE RESULT_FILE:gen_data_%d.out REPORT_STEPS:10,71,132,193,255,316,377,438 INPUT_FORMAT:ASCII

INSTALL_JOB heat_equation HEAT_EQUATION
Expand Down
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/corr_length_prior.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x NORMAL 0.8 0.1
23 changes: 19 additions & 4 deletions test-data/ert/heat_equation/heat_equation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
"""Partial Differential Equations to use as forward models."""

import json
import sys

import geostat
Expand Down Expand Up @@ -51,16 +52,28 @@ def heat_equation(
return u_


def sample_prior_conductivity(ensemble_size, nx, rng):
def sample_prior_conductivity(ensemble_size, nx, rng, corr_length):
mesh = np.meshgrid(np.linspace(0, 1, nx), np.linspace(0, 1, nx))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=0.8))
return np.exp(geostat.gaussian_fields(mesh, rng, ensemble_size, r=corr_length))


def load_parameters(filename):
with open(filename, encoding="utf-8") as f:
return json.load(f)


if __name__ == "__main__":
iens = int(sys.argv[1])
iteration = int(sys.argv[2])
rng = np.random.default_rng(iens)
cond = sample_prior_conductivity(ensemble_size=1, nx=nx, rng=rng).reshape(nx, nx)

parameters = load_parameters("parameters.json")
init_temp_scale = parameters["INIT_TEMP_SCALE"]
corr_length = parameters["CORR_LENGTH"]

cond = sample_prior_conductivity(
ensemble_size=1, nx=nx, rng=rng, corr_length=float(corr_length["x"])
).reshape(nx, nx)

if iteration == 0:
resfo.write(
Expand All @@ -78,7 +91,9 @@ def sample_prior_conductivity(ensemble_size, nx, rng):
# Note that this could be avoided if we used an implicit solver.
dt = dx**2 / (4 * max(np.max(cond), np.max(cond)))

response = heat_equation(u_init, cond, dx, dt, k_start, k_end, rng)
scaled_u_init = u_init * float(init_temp_scale["x"])

response = heat_equation(scaled_u_init, cond, dx, dt, k_start, k_end, rng)

index = sorted((obs.x, obs.y) for obs in obs_coordinates)
for time_step in obs_times:
Expand Down
1 change: 1 addition & 0 deletions test-data/ert/heat_equation/init_temp_prior.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
x UNIFORM 0 1
548 changes: 105 additions & 443 deletions test-data/ert/poly_example/Plot_correlations.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions test-data/ert/snake_oil_field/snake_oil.ert
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ GRID grid/CASE.EGRID
DEFINE <STORAGE> storage/<CONFIG_FILE_BASE>
RANDOM_SEED 3593114179000630026631423308983283277868

ANALYSIS_SET_VAR STD_ENKF LOCALIZATION True
ANALYSIS_SET_VAR STD_ENKF LOCALIZATION_CORRELATION_THRESHOLD 0.1

RUNPATH <STORAGE>/runpath/realization-<IENS>/iter-<ITER>
ENSPATH <STORAGE>/ensemble
JOBNAME SNAKE_OIL_<IENS>
Expand Down
16 changes: 16 additions & 0 deletions tests/ert/ui_tests/cli/analysis/test_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,22 @@ def test_update_multiple_param():
# https://en.wikipedia.org/wiki/Variance#For_vector-valued_random_variables
assert np.trace(np.cov(posterior_array)) < np.trace(np.cov(prior_array))

corr_XY = prior_ensemble.load_cross_correlations()
expected_obs_groups = [obs[0] for obs in ert_config.observation_config]
obs_groups = corr_XY["obs_group"].unique().to_list()
assert sorted(obs_groups) == sorted(expected_obs_groups)
# Check that obs names are created using obs groups
obs_name_starts_with_group = (
corr_XY.with_columns(
polars.col("obs_name")
.str.starts_with(polars.col("obs_group"))
.alias("starts_with_check")
)
.get_column("starts_with_check")
.all()
)
assert obs_name_starts_with_group


@pytest.mark.usefixtures("copy_poly_case")
def test_that_update_works_with_failed_realizations():
Expand Down
77 changes: 51 additions & 26 deletions tests/ert/ui_tests/cli/test_field_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,63 @@
from .run_cli import run_cli


def test_field_param_update_using_heat_equation(heat_equation_storage):
config = ErtConfig.from_file("config.ert")
with open_storage(config.ens_path, mode="w") as storage:
def test_shared_heat_equation_storage(heat_equation_storage):
"""The fixture heat_equation_storage runs the heat equation test case.
This test verifies that results are as expected.
"""
config = heat_equation_storage
with open_storage(config.ens_path) as storage:
experiment = storage.get_experiment_by_name("es-mda")
prior = experiment.get_ensemble_by_name("default_0")
posterior = experiment.get_ensemble_by_name("default_1")

prior_result = prior.load_parameters("COND")["values"]
ensembles = [experiment.get_ensemble_by_name(f"default_{i}") for i in range(4)]

param_config = config.ensemble_config.parameter_configs["COND"]
assert len(prior_result.x) == param_config.nx
assert len(prior_result.y) == param_config.ny
assert len(prior_result.z) == param_config.nz

posterior_result = posterior.load_parameters("COND")["values"]
prior_covariance = np.cov(
prior_result.values.reshape(
prior.ensemble_size, param_config.nx * param_config.ny * param_config.nz
),
rowvar=False,
)
posterior_covariance = np.cov(
posterior_result.values.reshape(
posterior.ensemble_size,
# Check that generalized variance decreases across consecutive ensembles
covariances = []
for ensemble in ensembles:
results = ensemble.load_parameters("COND")["values"]
reshaped_values = results.values.reshape(
ensemble.ensemble_size,
param_config.nx * param_config.ny * param_config.nz,
),
rowvar=False,
)
# Check that generalized variance is reduced by update step.
assert np.trace(prior_covariance) > np.trace(posterior_covariance)
)
covariances.append(np.cov(reshaped_values, rowvar=False))
for i in range(len(covariances) - 1):
assert np.trace(covariances[i]) > np.trace(
covariances[i + 1]
), f"Generalized variance did not decrease from iteration {i} to {i + 1}"

# Check that the saved cross-correlations are as expected.
for i in range(3):
ensemble = ensembles[i]
corr_XY = ensemble.load_cross_correlations()

assert sorted(corr_XY["param_group"].unique().to_list()) == [
"CORR_LENGTH",
"INIT_TEMP_SCALE",
]
assert corr_XY["param_name"].unique().to_list() == ["x"]

# Make sure correlations are between -1 and 1.
is_valid = (corr_XY["value"] >= -1) & (corr_XY["value"] <= 1)
assert is_valid.all()

# Check obs names and obs groups
expected_obs_groups = [obs[0] for obs in config.observation_config]
obs_groups = corr_XY["obs_group"].unique().to_list()
assert sorted(obs_groups) == sorted(expected_obs_groups)
# Check that obs names are created using obs groups
obs_name_starts_with_group = (
corr_XY.with_columns(
pl.col("obs_name")
.str.starts_with(pl.col("obs_group"))
.alias("starts_with_check")
)
.get_column("starts_with_check")
.all()
)
assert obs_name_starts_with_group

# Check that fields in the runpath are different between iterations
# Check that fields in the runpath are different between ensembles
cond_iter0 = resfo.read("simulations/realization-0/iter-0/cond.bgrdecl")[0][1]
cond_iter1 = resfo.read("simulations/realization-0/iter-1/cond.bgrdecl")[0][1]
assert (cond_iter0 != cond_iter1).all()
Expand Down

0 comments on commit c44eac7

Please sign in to comment.