diff --git a/icenet/__init__.py b/icenet/__init__.py index 01932438..dbd51d00 100644 --- a/icenet/__init__.py +++ b/icenet/__init__.py @@ -4,4 +4,4 @@ __copyright__ = "British Antarctic Survey" __email__ = "jambyr@bas.ac.uk" __license__ = "MIT" -__version__ = "0.2.7a0" +__version__ = "0.2.7a1" diff --git a/icenet/process/predict.py b/icenet/process/predict.py index c9cc0709..62a74d95 100644 --- a/icenet/process/predict.py +++ b/icenet/process/predict.py @@ -61,7 +61,7 @@ def get_refcube(north: bool = True, south: bool = False) -> object: def get_prediction_data(root: object, name: object, - date: object) -> object: + date: object) -> tuple: """ :param root: @@ -85,12 +85,13 @@ def get_prediction_data(root: object, data = [np.load(f) for f in np_files] data = np.array(data) + ens_members = data.shape[0] logging.debug("Data read from disk: {} from: {}".format(data.shape, np_files)) return np.stack( [data.mean(axis=0), data.std(axis=0)], - axis=-1).squeeze() + axis=-1).squeeze(), ens_members def date_arg(string: str) -> object: @@ -148,9 +149,9 @@ def create_cf_output(): for s in args.datefile.read().split()] args.datefile.close() - arr = np.array( - [get_prediction_data(args.root, args.name, date) - for date in dates]) + arr, ens_members = zip(*[get_prediction_data(args.root, args.name, date) for date in dates]) + ens_members = list(ens_members) + arr = np.array(arr) logging.info("Dataset arr shape: {}".format(arr.shape)) @@ -160,6 +161,19 @@ def create_cf_output(): if args.mask: mask_gen = Masks(north=ds.north, south=ds.south) + if args.agcm: + logging.info("Applying active grid cell masks") + + for idx, forecast_date in enumerate(dates): + for lead_idx in np.arange(0, arr.shape[3], 1): + lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1) + logging.debug("Active grid cell mask start {} forecast date {}". + format(forecast_date, lead_dt)) + + grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month) + sic_mean[idx, ~grid_cell_mask, lead_idx] = 0 + sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0 + if args.land: logging.info("Land masking the forecast output") land_mask = mask_gen.get_land_mask() @@ -176,19 +190,6 @@ def create_cf_output(): sic_mean[mask] = 0 sic_stddev[mask] = 0 - if args.agcm: - logging.info("Applying active grid cell masks") - - for idx, forecast_date in enumerate(dates): - for lead_idx in np.arange(0, arr.shape[3], 1): - lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1) - logging.debug("Active grid cell mask start {} forecast date {}". - format(forecast_date, lead_dt)) - - grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month) - sic_mean[idx, ~grid_cell_mask, lead_idx] = 0 - sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0 - lists_of_fcast_dates = [ [pd.Timestamp(date + dt.timedelta(days=int(lead_idx))) for lead_idx in np.arange(1, arr.shape[3] + 1, 1)] @@ -200,6 +201,7 @@ def create_cf_output(): Lambert_Azimuthal_Grid=ref_sic.Lambert_Azimuthal_Grid, sic_mean=(["time", "yc", "xc", "leadtime"], sic_mean), sic_stddev=(["time", "yc", "xc", "leadtime"], sic_stddev), + ensemble_members=(["time"], ens_members), ), coords=dict( time=[pd.Timestamp(d) for d in dates], @@ -260,7 +262,7 @@ def create_cf_output(): standard_name_vocabulary="CF Standard Name Table v27", summary=""" This is an output of sea ice concentration predictions from the - IceNet UNet run in an ensemble, with postprocessing to determine + IceNet run in an ensemble, with postprocessing to determine the mean and standard deviation across the runs. """, # Use ISO 8601:2004 duration format, preferably the extended format @@ -332,6 +334,12 @@ def create_cf_output(): units="1", ) + xarr.ensemble_members.attrs = dict( + long_name="number of ensemble members used to create this prediction", + short_name="ensemble_members", + # units="1", + ) + # TODO: split into daily files output_path = os.path.join(args.output_dir, "{}.nc".format(args.name)) logging.info("Saving to {}".format(output_path))