Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid redundant dataset loading #9016

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ert/config/gen_data_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class GenDataConfig(ResponseConfig):
report_steps_list: List[Optional[List[int]]] = dataclasses.field(
default_factory=list
)
has_finalized_keys: bool = True

def __post_init__(self) -> None:
if len(self.report_steps_list) == 0:
Expand Down
1 change: 1 addition & 0 deletions src/ert/config/response_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class ResponseConfig(ABC):
name: str
input_files: List[str] = dataclasses.field(default_factory=list)
keys: List[str] = dataclasses.field(default_factory=list)
has_finalized_keys: bool = False

@abstractmethod
def read_from_file(self, run_path: str, iens: int) -> polars.DataFrame:
Expand Down
1 change: 1 addition & 0 deletions src/ert/config/summary_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
class SummaryConfig(ResponseConfig):
name: str = "summary"
refcase: Union[Set[datetime], List[str], None] = None
has_finalized_keys = False

def __post_init__(self) -> None:
if isinstance(self.refcase, list):
Expand Down
132 changes: 80 additions & 52 deletions src/ert/dark_storage/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import logging
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple
from uuid import UUID

Expand All @@ -11,6 +12,8 @@
from ert.config.field import Field
from ert.storage import Ensemble, Experiment, Storage

logger = logging.getLogger(__name__)

response_key_to_displayed_key: Dict[str, Callable[[Tuple[Any, ...]], str]] = {
"summary": lambda t: t[0],
"gen_data": lambda t: f"{t[0]}@{t[1]}",
Expand Down Expand Up @@ -73,7 +76,13 @@ def ensemble_parameters(storage: Storage, ensemble_id: UUID) -> List[Dict[str, A
return param_list


def gen_data_keys(ensemble: Ensemble) -> Iterator[str]:
def get_response_names(ensemble: Ensemble) -> List[str]:
    """Return all response keys for *ensemble*: summary keys first (in stored
    order), followed by the gen_data display keys sorted case-insensitively.

    Parameters
    ----------
    ensemble
        Ensemble whose experiment provides the response key mapping.

    Returns
    -------
    List of response key names.
    """
    # Copy before extending: response_type_to_response_keys may hand back a
    # list owned by the experiment, and calling .extend() on it in place
    # would corrupt that shared state for later callers. Use .get() so an
    # experiment without summary responses yields only gen_data keys instead
    # of raising KeyError (consistent with the .get("summary", []) usage in
    # get_observation_keys_for_response / records.py).
    result = list(
        ensemble.experiment.response_type_to_response_keys.get("summary", [])
    )
    result.extend(sorted(gen_data_display_keys(ensemble), key=lambda k: k.lower()))
    return result


def gen_data_display_keys(ensemble: Ensemble) -> Iterator[str]:
gen_data_config = ensemble.experiment.response_configuration.get("gen_data")

if gen_data_config:
Expand All @@ -95,35 +104,75 @@ def data_for_key(
"""Returns a pandas DataFrame with the datapoints for a given key for a
given ensemble. The row index is the realization number, and the columns are an
index over the indexes/dates"""

if key.startswith("LOG10_"):
key = key[6:]

try:
summary_data = ensemble.load_responses(
"summary", tuple(ensemble.get_realization_list_with_responses("summary"))
)
summary_keys = summary_data["response_key"].unique().to_list()
except (ValueError, KeyError, polars.exceptions.ColumnNotFoundError):
summary_data = polars.DataFrame()
summary_keys = []

if key in summary_keys:
df = (
summary_data.filter(polars.col("response_key").eq(key))
.rename({"time": "Date", "realization": "Realization"})
.drop("response_key")
.to_pandas()
response_key_to_response_type = ensemble.experiment.response_key_to_response_type

# Check for exact match first. For example if key is "FOPRH"
# it may stop at "FOPR", which would be incorrect
response_key = next((k for k in response_key_to_response_type if k == key), None)
if response_key is None:
response_key = next(
(k for k in response_key_to_response_type if k in key), None
)
df = df.set_index(["Date", "Realization"])
# This performs the same aggragation by mean of duplicate values
# as in ert/analysis/_es_update.py
df = df.groupby(["Date", "Realization"]).mean()
data = df.unstack(level="Date")
data.columns = data.columns.droplevel(0)
try:
return data.astype(float)
except ValueError:
return data

if response_key is not None:
response_type = response_key_to_response_type[response_key]

if response_type == "summary":
summary_data = ensemble.load_responses(
response_key,
tuple(ensemble.get_realization_list_with_responses(response_key)),
)
if summary_data.is_empty():
return pd.DataFrame()

df = (
summary_data.rename({"time": "Date", "realization": "Realization"})
.drop("response_key")
.to_pandas()
)
df = df.set_index(["Date", "Realization"])
# This performs the same aggragation by mean of duplicate values
# as in ert/analysis/_es_update.py
df = df.groupby(["Date", "Realization"]).mean()
data = df.unstack(level="Date")
data.columns = data.columns.droplevel(0)
try:
return data.astype(float)
except ValueError:
return data

if response_type == "gen_data":
try:
# Call below will ValueError if key ends with H,
oyvindeide marked this conversation as resolved.
Show resolved Hide resolved
# requested via PlotAPI.history_data
response_key, report_step = displayed_key_to_response_key["gen_data"](
key
)
mask = ensemble.get_realization_mask_with_responses(response_key)
realizations = np.where(mask)[0]
data = ensemble.load_responses(response_key, tuple(realizations))
except ValueError as err:
logger.info(f"Dark storage could not load response {key}: {err}")
return pd.DataFrame()

try:
vals = data.filter(polars.col("report_step").eq(report_step))
oyvindeide marked this conversation as resolved.
Show resolved Hide resolved
pivoted = vals.drop("response_key", "report_step").pivot(
on="index", values="values"
)
data = pivoted.to_pandas().set_index("realization")
data.columns = data.columns.astype(int)
data.columns.name = "axis"
try:
return data.astype(float)
except ValueError:
return data
oyvindeide marked this conversation as resolved.
Show resolved Hide resolved
except (ValueError, KeyError):
return pd.DataFrame()

group = key.split(":")[0]
parameters = ensemble.experiment.parameter_configuration
Expand Down Expand Up @@ -162,30 +211,6 @@ def data_for_key(
return data.astype(float)
except ValueError:
return data
if key in gen_data_keys(ensemble):
oyvindeide marked this conversation as resolved.
Show resolved Hide resolved
response_key, report_step = displayed_key_to_response_key["gen_data"](key)
try:
mask = ensemble.get_realization_mask_with_responses(response_key)
realizations = np.where(mask)[0]
data = ensemble.load_responses(response_key, tuple(realizations))
except ValueError as err:
print(f"Could not load response {key}: {err}")
return pd.DataFrame()

try:
vals = data.filter(polars.col("report_step").eq(report_step))
pivoted = vals.drop("response_key", "report_step").pivot(
on="index", values="values"
)
data = pivoted.to_pandas().set_index("realization")
data.columns = data.columns.astype(int)
data.columns.name = "axis"
try:
return data.astype(float)
except ValueError:
return data
except (ValueError, KeyError):
return pd.DataFrame()

return pd.DataFrame()

Expand Down Expand Up @@ -245,7 +270,7 @@ def get_observation_keys_for_response(
Get all observation keys for given response key
"""

if displayed_response_key in gen_data_keys(ensemble):
if displayed_response_key in gen_data_display_keys(ensemble):
response_key, report_step = displayed_key_to_response_key["gen_data"](
displayed_response_key
)
Expand All @@ -262,7 +287,10 @@ def get_observation_keys_for_response(

return filtered["observation_key"].unique().to_list()

elif displayed_response_key in ensemble.get_summary_keyset():
elif (
displayed_response_key
in ensemble.experiment.response_type_to_response_keys["summary"]
):
response_key = displayed_key_to_response_key["summary"](displayed_response_key)[
0
]
Expand Down
15 changes: 12 additions & 3 deletions src/ert/dark_storage/endpoints/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from ert.dark_storage.common import (
data_for_key,
ensemble_parameters,
gen_data_keys,
gen_data_display_keys,
get_observation_keys_for_response,
get_observations_for_obs_keys,
response_key_to_displayed_key,
)
from ert.dark_storage.enkf import get_storage
from ert.storage import Storage
from ert.storage.realization_storage_state import RealizationStorageState

router = APIRouter(tags=["record"])

Expand Down Expand Up @@ -133,15 +134,23 @@ def get_ensemble_responses(
)
response_names_with_observations.update(set(obs_with_responses))

for name in ensemble.get_summary_keyset():
has_responses = any(
s == RealizationStorageState.HAS_DATA for s in ensemble.get_ensemble_state()
)

for name in (
ensemble.experiment.response_type_to_response_keys.get("summary", [])
if has_responses
else []
):
response_map[str(name)] = js.RecordOut(
id=UUID(int=0),
name=name,
userdata={"data_origin": "Summary"},
has_observations=name in response_names_with_observations,
)

for name in gen_data_keys(ensemble):
for name in gen_data_display_keys(ensemble) if has_responses else []:
response_map[str(name)] = js.RecordOut(
id=UUID(int=0),
name=name,
Expand Down
46 changes: 19 additions & 27 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
logger = logging.getLogger(__name__)

import polars
from polars.exceptions import ColumnNotFoundError


class _Index(BaseModel):
Expand Down Expand Up @@ -311,11 +310,22 @@ def _has_response(_key: str) -> bool:
if key:
return _has_response(key)

return all(
_has_response(response)
for response in self.experiment.response_configuration
is_expecting_any_responses = any(
bool(config.keys)
for config in self.experiment.response_configuration.values()
)

if not is_expecting_any_responses:
return True

non_empty_response_configs = [
response
for response, config in self.experiment.response_configuration.items()
if bool(config.keys)
]

return all(_has_response(response) for response in non_empty_response_configs)

def is_initalized(self) -> List[int]:
"""
Return the realization numbers where all parameters are internalized. In
Expand Down Expand Up @@ -502,27 +512,6 @@ def _find_state(realization: int) -> RealizationStorageState:

return [_find_state(i) for i in range(self.ensemble_size)]

def get_summary_keyset(self) -> List[str]:
    """
    Find the first folder with summary data then load the
    summary keys from this.

    Returns
    -------
    keys : list of str
        List of summary keys.
    """

    # Best-effort: if no summary responses are stored (or the stored frame
    # lacks a "response_key" column), report an empty keyset rather than
    # propagating the error to the caller.
    try:
        realizations = tuple(self.get_realization_list_with_responses("summary"))
        keys = (
            self.load_responses("summary", realizations)["response_key"]
            .unique()
            .to_list()
        )
    except (ValueError, KeyError, ColumnNotFoundError):
        return []
    return sorted(keys)

def _load_single_dataset(
self,
group: str,
Expand Down Expand Up @@ -696,8 +685,6 @@ def load_all_summary_data(
raise IndexError(f"No such realization {realization_index}")
realizations = [realization_index]

summary_keys = self.get_summary_keyset()

try:
df_pl = self.load_responses("summary", tuple(realizations))

Expand All @@ -715,6 +702,7 @@ def load_all_summary_data(
)

if keys:
summary_keys = self.experiment.response_type_to_response_keys["summary"]
summary_keys = sorted(
[key for key in keys if key in summary_keys]
) # ignore keys that doesn't exist
Expand Down Expand Up @@ -877,6 +865,10 @@ def save_response(
output_path / f"{response_type}.parquet", data
)

if not self.experiment._has_finalized_response_keys(response_type):
response_keys = data["response_key"].unique().to_list()
self.experiment._update_response_keys(response_type, response_keys)

def calculate_std_dev_for_parameter(self, parameter_group: str) -> xr.Dataset:
if parameter_group not in self.experiment.parameter_configuration:
raise ValueError(f"{parameter_group} is not registered to the experiment.")
Expand Down
Loading
Loading