From 0d8ddd463b8e20d87c32329914575ee8c8dc090f Mon Sep 17 00:00:00 2001 From: "Yngve S. Kristiansen" Date: Wed, 25 Sep 2024 08:20:02 +0200 Subject: [PATCH] Refactor rft export logic for polars --- .../scripts/gen_data_rft_export.py | 117 ++++++++++-------- 1 file changed, 65 insertions(+), 52 deletions(-) diff --git a/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py b/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py index f985f5a19a4..6ad132d8a9f 100644 --- a/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py +++ b/src/ert/resources/workflows/jobs/internal-gui/scripts/gen_data_rft_export.py @@ -4,6 +4,7 @@ import numpy import pandas as pd +import polars from qtpy.QtWidgets import QCheckBox from ert.config import CancelPluginException, ErtPlugin @@ -85,8 +86,6 @@ def run( ensemble_data_as_json = None if len(workflow_args) < 3 else workflow_args[2] drop_const_cols = False if len(workflow_args) < 4 else bool(workflow_args[3]) - wells = set() - ensemble_data_as_dict = ( json.loads(ensemble_data_as_json) if ensemble_data_as_json else {} ) @@ -110,82 +109,96 @@ def run( f"The ensemble '{ensemble_name}' does not have any data!" ) - obs = ensemble.experiment.observations + obs_df = ensemble.experiment.observations.get("gen_data") obs_keys = [] - for key, _ in obs.items(): + for key in ensemble.experiment.observation_keys: if key.startswith("RFT_"): obs_keys.append(key) - if len(obs_keys) == 0: + if len(obs_keys) == 0 or obs_df is None: raise UserWarning( "The config does not contain any" " GENERAL_OBSERVATIONS starting with RFT_*" ) - ensemble_data = [] for obs_key in obs_keys: - well = obs_key.replace("RFT_", "") - wells.add(well) - obs_vector = obs[obs_key] - data_key = obs_vector.attrs["response"] - if len(obs_vector.report_step) == 1: - report_step = obs_vector.report_step.values - obs_node = obs_vector.sel(report_step=report_step) - else: + well_key = obs_key.replace("RFT_", "") + + obs_df = obs_df.filter(polars.col("observation_key").eq(obs_key)) + response_key = obs_df["response_key"].unique().to_list()[0] + + if len(obs_df["report_step"].unique()) != 1: raise UserWarning( "GEN_DATA RFT CSV Export can only be used for observations " "active for exactly one report step" ) - realizations = ensemble.get_realization_list_with_responses(data_key) - vals = ensemble.load_responses(data_key, tuple(realizations)).sel( - report_step=report_step, drop=True - ) - index = pd.Index(vals.index.values, name="axis") - rft_data = pd.DataFrame( - data=vals["values"].values.reshape(len(vals.realization), -1).T, - index=index, - columns=realizations, + realizations = ensemble.get_realization_list_with_responses( + response_key ) + responses = ensemble.load_responses(response_key, tuple(realizations)) + joined = obs_df.join( + responses, + on=["response_key", "report_step", "index"], + how="left", + ).drop("index", "report_step") # Trajectory - trajectory_file = os.path.join(trajectory_path, f"{well}.txt") + trajectory_file = os.path.join(trajectory_path, f"{well_key}.txt") if not os.path.isfile(trajectory_file): - trajectory_file = os.path.join(trajectory_path, f"{well}_R.txt") + trajectory_file = os.path.join(trajectory_path, f"{well_key}_R.txt") arg = load_args( trajectory_file, column_names=["utm_x", "utm_y", "md", "tvd"] ) tvd_arg = arg["tvd"] - # Observations - for iens in realizations: - realization_frame = pd.DataFrame( - data={ - "TVD": tvd_arg, - "Pressure": rft_data[iens], - "ObsValue": obs_node["observations"].values[0], - "ObsStd": obs_node["std"].values[0], - }, - columns=["TVD", "Pressure", "ObsValue", "ObsStd"], - ) - - realization_frame["Realization"] = iens - realization_frame["Well"] = well - realization_frame["Ensemble"] = ensemble_name - realization_frame["Iteration"] = ensemble.iteration - - ensemble_data.append(realization_frame) - - data.append(pd.concat(ensemble_data)) - - frame = pd.concat(data) - frame.set_index(["Realization", "Well", "Ensemble", "Iteration"], inplace=True) - if drop_const_cols: - frame = frame.loc[:, (frame != frame.iloc[0]).any()] + all_realization_frames = joined.rename( + { + "realization": "Realization", + "values": "Pressure", + "observations": "ObsValue", + "std": "ObsStd", + } + ).with_columns( + [ + polars.lit(well_key).alias("Well").cast(polars.String), + polars.lit(ensemble.name).alias("Ensemble").cast(polars.String), + polars.lit(ensemble.iteration) + .alias("Iteration") + .cast(polars.UInt8), + polars.lit(tvd_arg).alias("TVD").cast(polars.Float32), + ] + ) - frame.to_csv(output_file) - well_list_str = ", ".join(list(wells)) + data.append(all_realization_frames) + + frame = polars.concat(data) + + cols_index = ["Well", "Ensemble", "Iteration"] + const_cols_right = ["ObsValue", "ObsStd"] + const_cols_left = [ + col + for col in frame.columns + if ( + col not in cols_index + and col not in const_cols_right + and frame[col].n_unique() == 1 + ) + ] + + columns_to_export = [ + "Realization", + *cols_index, + *(const_cols_left if not drop_const_cols else []), + *["Pressure"], + *(const_cols_right if not drop_const_cols else []), + ] + + to_export = frame.select(columns_to_export) + + to_export.write_csv(output_file, include_header=True) + well_list_str = ", ".join(to_export["Well"].unique().to_list()) export_info = ( f"Exported RFT information for wells: {well_list_str} to: {output_file}" )