Skip to content

Commit

Permalink
Add separate validation argument for realization_in_ensemble
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Oct 16, 2024
1 parent c3214c2 commit 683f3ee
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 29 deletions.
19 changes: 15 additions & 4 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _get_observations_and_responses(
) -> polars.DataFrame:
"""Fetches and aligns selected observations with their corresponding simulated responses from an ensemble."""
observations_by_type = ensemble.experiment.observations

print(f"{iens_active_index=}")
df = polars.DataFrame()
for (
response_type,
Expand All @@ -168,6 +168,7 @@ def _get_observations_and_responses(
# Note that if there are duplicate entries for one
# response at one index, they are aggregated together
# with "mean" by default
print("PIVOTED HERE")
pivoted = responses_for_type.pivot(
on="realization",
index=["response_key", *response_cls.primary_key],
Expand All @@ -181,13 +182,15 @@ def _get_observations_and_responses(
# one "time" column, which we will reindex towards the response dataset
# with a given resolution
if "time" in pivoted:
print("TIM EHERE")
joined = observations_for_type.join_asof(
pivoted,
by=["response_key", *response_cls.primary_key],
on="time",
tolerance="1s",
)
else:
print("JOINED HERE")
joined = observations_for_type.join(
pivoted,
how="left",
Expand All @@ -203,21 +206,24 @@ def _get_observations_and_responses(
.drop(response_cls.primary_key)
.rename({"__tmp_index_key__": "index"})
)

print("FIRST COLUMNS")
first_columns = [
"response_key",
"index",
"observation_key",
"observations",
"std",
]
print("JOINED 0")
joined = joined.select(
first_columns + [c for c in joined.columns if c not in first_columns]
)

print("JOINED 1")
df.vstack(joined, in_place=True)
print("STACKED")

ensemble.load_responses.cache_clear()
print("LOADED ENSMBLs")
return df


Expand Down Expand Up @@ -525,8 +531,9 @@ def analysis_ES(
progress_callback: Callable[[AnalysisEvent], None],
auto_scale_observations: Optional[List[ObservationGroups]],
) -> None:
print(f"{ens_mask=}")
iens_active_index = np.flatnonzero(ens_mask)

print(f"{iens_active_index=}")
ensemble_size = ens_mask.sum()

def adaptive_localization_progress_callback(
Expand Down Expand Up @@ -707,7 +714,11 @@ def analysis_IES(
sies_step_length: Callable[[int], float],
initial_mask: npt.NDArray[np.bool_],
) -> ies.SIES:
print("B")
print(f"{ens_mask=}")
iens_active_index = np.flatnonzero(ens_mask)
print(f"{iens_active_index=}")
print("END B")
# Pick out realizations that were among the initials that are still living
# Example: initial_mask=[1,1,1,0,1], ens_mask=[0,1,1,0,1]
# Then the result is [0,1,1,1]
Expand Down
40 changes: 15 additions & 25 deletions src/ert/gui/simulation/manual_update_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from ert.gui.simulation.experiment_config_panel import ExperimentConfigPanel
from ert.mode_definitions import MANUAL_UPDATE_MODE
from ert.run_models.manual_update import ManualUpdate
from ert.validation import ProperNameFormatArgument, RangeStringArgument
from ert.validation.rangestring import rangestring_to_list
from ert.validation import ProperNameFormatArgument
from ert.validation.realizations_in_ensemble_argument import (
RealizationsInEnsembleArgument,
)


@dataclass
Expand Down Expand Up @@ -47,6 +49,11 @@ def __init__(
lab.setAlignment(QtCore.Qt.AlignmentFlag.AlignLeft)
layout.addRow(lab)
self._ensemble_selector = EnsembleSelector(notifier)
self._ensemble_selector.ensemble_populated.connect(self._realizations_from_fs)
self._ensemble_selector.ensemble_populated.connect(
self.simulationConfigurationChanged
)
self._ensemble_selector.currentIndexChanged.connect(self._realizations_from_fs)
layout.addRow("Ensemble:", self._ensemble_selector)
runpath_label = CopyableLabel(text=run_path)
layout.addRow("Runpath:", runpath_label)
Expand All @@ -73,9 +80,10 @@ def __init__(
ActiveRealizationsModel(ensemble_size, show_default=False), # type: ignore
"config/simulation/active_realizations",
)
self._active_realizations_field.setValidator(
RangeStringArgument(ensemble_size),
self._realizations_validator = RealizationsInEnsembleArgument(
self._ensemble_selector.selected_ensemble
)
self._active_realizations_field.setValidator(self._realizations_validator)
self._realizations_from_fs()
layout.addRow("Active realizations", self._active_realizations_field)

Expand All @@ -84,34 +92,14 @@ def __init__(
self._active_realizations_field.getValidationSupport().validationChanged.connect(
self.simulationConfigurationChanged
)
self._ensemble_selector.ensemble_populated.connect(self._realizations_from_fs)
self._ensemble_selector.ensemble_populated.connect(
self.simulationConfigurationChanged
)
self._ensemble_selector.currentIndexChanged.connect(self._realizations_from_fs)

def isConfigurationValid(self) -> bool:
return (
self._active_realizations_field.isValid()
and self._ensemble_selector.currentIndex() != -1
and self._validate_selected_realization_exist()
and self._active_realizations_field.isValid()
)

def _validate_selected_realization_exist(self):
realizations = rangestring_to_list(self._active_realizations_field.text())
if len(realizations) < 1:
print("NO REALIZATIONS GIVEN")
return False
selected_ensemble = self._ensemble_selector.selected_ensemble
for realization_index in realizations:
if not selected_ensemble._responses_exist_for_realization(
realization_index
):
print(f"{realization_index=} does not exist!")
return False
print(f"VALID REALIZATIONS {realizations=}")
return True

def get_experiment_arguments(self) -> Arguments:
return Arguments(
mode=MANUAL_UPDATE_MODE,
Expand All @@ -122,6 +110,8 @@ def get_experiment_arguments(self) -> Arguments:

def _realizations_from_fs(self) -> None:
ensemble = self._ensemble_selector.selected_ensemble
print(f"{ensemble=}")
self._realizations_validator.__ensemble = ensemble
if ensemble:
parameters = ensemble.get_realization_mask_with_parameters()
responses = ensemble.get_realization_mask_with_responses()
Expand Down
1 change: 1 addition & 0 deletions src/ert/run_models/base_run_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def send_smoother_event(
)
)
elif isinstance(event, AnalysisErrorEvent):
print("FOUND ME")
self.send_event(
RunModelErrorEvent(
iteration=iteration,
Expand Down
2 changes: 2 additions & 0 deletions src/ert/storage/local_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,8 @@ def save_cross_correlations(

@lru_cache # noqa: B019
def load_responses(self, key: str, realizations: Tuple[int]) -> polars.DataFrame:
print(f"{realizations=}")
print(f"{key=}")
"""Load responses for key and realizations into xarray Dataset.
For each given realization, response data is loaded from the NetCDF
Expand Down
61 changes: 61 additions & 0 deletions src/ert/validation/realizations_in_ensemble_argument.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from ert.storage import Ensemble
from ert.validation.range_string_argument import RangeStringArgument
from ert.validation.rangestring import rangestring_to_list

from .validation_status import ValidationStatus


class RealizationsInEnsembleArgument(RangeStringArgument):
UNINITIALIZED_REALIZATIONS_SPECIFIED = (
"The specified realization(s) %s are not found in selected ensemble."
)
NO_REALIZATIONS_SPECIFIED = "No realizations specified."

def __init__(self, ensemble: Ensemble, **kwargs: bool) -> None:
super().__init__(**kwargs)
self.__ensemble = ensemble

def validate(self, token: str) -> ValidationStatus:
if not token:
return ValidationStatus()

validation_status = super().validate(token)
if not validation_status:
return validation_status
print(f"{validation_status.message()=}")
attempted_realizations = rangestring_to_list(token) # might be duplicate
### NOT NECCESSARY IF WE WANT EMPTY TO BE VALID. IN THIS CONTEXT, 'NO' REALIZATION IS VALID, AS IT WILL BE IN ENSEMBLE
if len(attempted_realizations) < 1:
validation_status.setFailed()
validation_status.addToMessage(
RealizationsInEnsembleArgument.NO_REALIZATIONS_SPECIFIED
)
if self.__ensemble is None:
validation_status.setFailed()
validation_status.addToMessage("NO ENSEMBLE FOUND!")
return validation_status
invalid_realizations = []
for realization in attempted_realizations:
if not self._validate_selected_realization_exist(realization):
invalid_realizations.append(realization)

if invalid_realizations:
print(f"{self.__ensemble.get_realization_mask_with_responses()=}")
validation_status.setFailed()
validation_status.addToMessage(
RealizationsInEnsembleArgument.UNINITIALIZED_REALIZATIONS_SPECIFIED
% str(invalid_realizations)
)

elif not validation_status.failed():
validation_status.setValue(token)

return validation_status

def _validate_selected_realization_exist(self, realization: int):

Check failure on line 55 in src/ert/validation/realizations_in_ensemble_argument.py

View workflow job for this annotation

GitHub Actions / type-checking (3.12)

Function is missing a return type annotation
if self.__ensemble._responses_exist_for_realization(realization):
print(f"VALID REALIZATION {realization=}")
return True

# print(f"{realization=} does not exist!")
return False

0 comments on commit 683f3ee

Please sign in to comment.