From 683f3eeccfd602ce33ffebe5b167b9ec91db0c0e Mon Sep 17 00:00:00 2001 From: Jonathan Karlsen Date: Wed, 16 Oct 2024 16:49:04 +0200 Subject: [PATCH] Add separate validation argument for realization_in_ensemble --- src/ert/analysis/_es_update.py | 19 ++++-- src/ert/gui/simulation/manual_update_panel.py | 40 +++++------- src/ert/run_models/base_run_model.py | 1 + src/ert/storage/local_ensemble.py | 2 + .../realizations_in_ensemble_argument.py | 61 +++++++++++++++++++ 5 files changed, 94 insertions(+), 29 deletions(-) create mode 100644 src/ert/validation/realizations_in_ensemble_argument.py diff --git a/src/ert/analysis/_es_update.py b/src/ert/analysis/_es_update.py index 89707153b49..ff622908846 100644 --- a/src/ert/analysis/_es_update.py +++ b/src/ert/analysis/_es_update.py @@ -149,7 +149,7 @@ def _get_observations_and_responses( ) -> polars.DataFrame: """Fetches and aligns selected observations with their corresponding simulated responses from an ensemble.""" observations_by_type = ensemble.experiment.observations - + print(f"{iens_active_index=}") df = polars.DataFrame() for ( response_type, @@ -168,6 +168,7 @@ def _get_observations_and_responses( # Note that if there are duplicate entries for one # response at one index, they are aggregated together # with "mean" by default + print("PIVOTED HERE") pivoted = responses_for_type.pivot( on="realization", index=["response_key", *response_cls.primary_key], @@ -181,6 +182,7 @@ def _get_observations_and_responses( # one "time" column, which we will reindex towards the response dataset # with a given resolution if "time" in pivoted: + print("TIM EHERE") joined = observations_for_type.join_asof( pivoted, by=["response_key", *response_cls.primary_key], @@ -188,6 +190,7 @@ def _get_observations_and_responses( tolerance="1s", ) else: + print("JOINED HERE") joined = observations_for_type.join( pivoted, how="left", @@ -203,7 +206,7 @@ def _get_observations_and_responses( .drop(response_cls.primary_key) .rename({"__tmp_index_key__": "index"}) ) - + print("FIRST COLUMNS") first_columns = [ "response_key", "index", @@ -211,13 +214,16 @@ def _get_observations_and_responses( "observations", "std", ] + print("JOINED 0") joined = joined.select( first_columns + [c for c in joined.columns if c not in first_columns] ) - + print("JOINED 1") df.vstack(joined, in_place=True) + print("STACKED") ensemble.load_responses.cache_clear() + print("LOADED ENSMBLs") return df @@ -525,8 +531,9 @@ def analysis_ES( progress_callback: Callable[[AnalysisEvent], None], auto_scale_observations: Optional[List[ObservationGroups]], ) -> None: + print(f"{ens_mask=}") iens_active_index = np.flatnonzero(ens_mask) - + print(f"{iens_active_index=}") ensemble_size = ens_mask.sum() def adaptive_localization_progress_callback( @@ -707,7 +714,11 @@ def analysis_IES( sies_step_length: Callable[[int], float], initial_mask: npt.NDArray[np.bool_], ) -> ies.SIES: + print("B") + print(f"{ens_mask=}") iens_active_index = np.flatnonzero(ens_mask) + print(f"{iens_active_index=}") + print("END B") # Pick out realizations that were among the initials that are still living # Example: initial_mask=[1,1,1,0,1], ens_mask=[0,1,1,0,1] # Then the result is [0,1,1,1] diff --git a/src/ert/gui/simulation/manual_update_panel.py b/src/ert/gui/simulation/manual_update_panel.py index 23016fa9a17..14c80d21818 100644 --- a/src/ert/gui/simulation/manual_update_panel.py +++ b/src/ert/gui/simulation/manual_update_panel.py @@ -17,8 +17,10 @@ from ert.gui.simulation.experiment_config_panel import ExperimentConfigPanel from ert.mode_definitions import MANUAL_UPDATE_MODE from ert.run_models.manual_update import ManualUpdate -from ert.validation import ProperNameFormatArgument, RangeStringArgument -from ert.validation.rangestring import rangestring_to_list +from ert.validation import ProperNameFormatArgument +from ert.validation.realizations_in_ensemble_argument import ( + RealizationsInEnsembleArgument, +) @dataclass @@ -47,6 +49,11 @@ def __init__( lab.setAlignment(QtCore.Qt.AlignmentFlag.AlignLeft) layout.addRow(lab) self._ensemble_selector = EnsembleSelector(notifier) + self._ensemble_selector.ensemble_populated.connect(self._realizations_from_fs) + self._ensemble_selector.ensemble_populated.connect( + self.simulationConfigurationChanged + ) + self._ensemble_selector.currentIndexChanged.connect(self._realizations_from_fs) layout.addRow("Ensemble:", self._ensemble_selector) runpath_label = CopyableLabel(text=run_path) layout.addRow("Runpath:", runpath_label) @@ -73,9 +80,10 @@ def __init__( ActiveRealizationsModel(ensemble_size, show_default=False), # type: ignore "config/simulation/active_realizations", ) - self._active_realizations_field.setValidator( - RangeStringArgument(ensemble_size), + self._realizations_validator = RealizationsInEnsembleArgument( + self._ensemble_selector.selected_ensemble ) + self._active_realizations_field.setValidator(self._realizations_validator) self._realizations_from_fs() layout.addRow("Active realizations", self._active_realizations_field) @@ -84,34 +92,14 @@ def __init__( self._active_realizations_field.getValidationSupport().validationChanged.connect( self.simulationConfigurationChanged ) - self._ensemble_selector.ensemble_populated.connect(self._realizations_from_fs) - self._ensemble_selector.ensemble_populated.connect( - self.simulationConfigurationChanged - ) - self._ensemble_selector.currentIndexChanged.connect(self._realizations_from_fs) def isConfigurationValid(self) -> bool: return ( self._active_realizations_field.isValid() and self._ensemble_selector.currentIndex() != -1 - and self._validate_selected_realization_exist() + and self._active_realizations_field.isValid() ) - def _validate_selected_realization_exist(self): - realizations = rangestring_to_list(self._active_realizations_field.text()) - if len(realizations) < 1: - print("NO REALIZATIONS GIVEN") - return False - selected_ensemble = self._ensemble_selector.selected_ensemble - for realization_index in realizations: - if not selected_ensemble._responses_exist_for_realization( - realization_index - ): - print(f"{realization_index=} does not exist!") - return False - print(f"VALID REALIZATIONS {realizations=}") - return True - def get_experiment_arguments(self) -> Arguments: return Arguments( mode=MANUAL_UPDATE_MODE, @@ -122,6 +110,8 @@ def get_experiment_arguments(self) -> Arguments: def _realizations_from_fs(self) -> None: ensemble = self._ensemble_selector.selected_ensemble + print(f"{ensemble=}") + self._realizations_validator.__ensemble = ensemble if ensemble: parameters = ensemble.get_realization_mask_with_parameters() responses = ensemble.get_realization_mask_with_responses() diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 188b0ad8c1f..dcda6fb756f 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -243,6 +243,7 @@ def send_smoother_event( ) ) elif isinstance(event, AnalysisErrorEvent): + print("FOUND ME") self.send_event( RunModelErrorEvent( iteration=iteration, diff --git a/src/ert/storage/local_ensemble.py b/src/ert/storage/local_ensemble.py index 831c26d1f0b..517cd9c056d 100644 --- a/src/ert/storage/local_ensemble.py +++ b/src/ert/storage/local_ensemble.py @@ -627,6 +627,8 @@ def save_cross_correlations( @lru_cache # noqa: B019 def load_responses(self, key: str, realizations: Tuple[int]) -> polars.DataFrame: + print(f"{realizations=}") + print(f"{key=}") """Load responses for key and realizations into xarray Dataset. For each given realization, response data is loaded from the NetCDF diff --git a/src/ert/validation/realizations_in_ensemble_argument.py b/src/ert/validation/realizations_in_ensemble_argument.py new file mode 100644 index 00000000000..28e3b0d3ca3 --- /dev/null +++ b/src/ert/validation/realizations_in_ensemble_argument.py @@ -0,0 +1,61 @@ +from ert.storage import Ensemble +from ert.validation.range_string_argument import RangeStringArgument +from ert.validation.rangestring import rangestring_to_list + +from .validation_status import ValidationStatus + + +class RealizationsInEnsembleArgument(RangeStringArgument): + UNINITIALIZED_REALIZATIONS_SPECIFIED = ( + "The specified realization(s) %s are not found in selected ensemble." + ) + NO_REALIZATIONS_SPECIFIED = "No realizations specified." + + def __init__(self, ensemble: Ensemble, **kwargs: bool) -> None: + super().__init__(**kwargs) + self.__ensemble = ensemble + + def validate(self, token: str) -> ValidationStatus: + if not token: + return ValidationStatus() + + validation_status = super().validate(token) + if not validation_status: + return validation_status + print(f"{validation_status.message()=}") + attempted_realizations = rangestring_to_list(token) # might be duplicate + ### NOT NECCESSARY IF WE WANT EMPTY TO BE VALID. IN THIS CONTEXT, 'NO' REALIZATION IS VALID, AS IT WILL BE IN ENSEMBLE + if len(attempted_realizations) < 1: + validation_status.setFailed() + validation_status.addToMessage( + RealizationsInEnsembleArgument.NO_REALIZATIONS_SPECIFIED + ) + if self.__ensemble is None: + validation_status.setFailed() + validation_status.addToMessage("NO ENSEMBLE FOUND!") + return validation_status + invalid_realizations = [] + for realization in attempted_realizations: + if not self._validate_selected_realization_exist(realization): + invalid_realizations.append(realization) + + if invalid_realizations: + print(f"{self.__ensemble.get_realization_mask_with_responses()=}") + validation_status.setFailed() + validation_status.addToMessage( + RealizationsInEnsembleArgument.UNINITIALIZED_REALIZATIONS_SPECIFIED + % str(invalid_realizations) + ) + + elif not validation_status.failed(): + validation_status.setValue(token) + + return validation_status + + def _validate_selected_realization_exist(self, realization: int): + if self.__ensemble._responses_exist_for_realization(realization): + print(f"VALID REALIZATION {realization=}") + return True + + # print(f"{realization=} does not exist!") + return False