Skip to content

Commit

Permalink
Fixes #187 and #191: better masking, ensemble member numbers per prediction…
Browse files Browse the repository at this point in the history
… added in
  • Loading branch information
JimCircadian committed Sep 22, 2023
1 parent c573671 commit 2606519
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 20 deletions.
2 changes: 1 addition & 1 deletion icenet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
__copyright__ = "British Antarctic Survey"
__email__ = "[email protected]"
__license__ = "MIT"
__version__ = "0.2.7a0"
__version__ = "0.2.7a1"
46 changes: 27 additions & 19 deletions icenet/process/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_refcube(north: bool = True, south: bool = False) -> object:

def get_prediction_data(root: object,
name: object,
date: object) -> object:
date: object) -> tuple:
"""
:param root:
Expand All @@ -85,12 +85,13 @@ def get_prediction_data(root: object,

data = [np.load(f) for f in np_files]
data = np.array(data)
ens_members = data.shape[0]

logging.debug("Data read from disk: {} from: {}".format(data.shape, np_files))

return np.stack(
[data.mean(axis=0), data.std(axis=0)],
axis=-1).squeeze()
axis=-1).squeeze(), ens_members


def date_arg(string: str) -> object:
Expand Down Expand Up @@ -148,9 +149,9 @@ def create_cf_output():
for s in args.datefile.read().split()]
args.datefile.close()

arr = np.array(
[get_prediction_data(args.root, args.name, date)
for date in dates])
arr, ens_members = zip(*[get_prediction_data(args.root, args.name, date) for date in dates])
ens_members = list(ens_members)
arr = np.array(arr)

logging.info("Dataset arr shape: {}".format(arr.shape))

Expand All @@ -160,6 +161,19 @@ def create_cf_output():
if args.mask:
mask_gen = Masks(north=ds.north, south=ds.south)

if args.agcm:
logging.info("Applying active grid cell masks")

for idx, forecast_date in enumerate(dates):
for lead_idx in np.arange(0, arr.shape[3], 1):
lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1)
logging.debug("Active grid cell mask start {} forecast date {}".
format(forecast_date, lead_dt))

grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month)
sic_mean[idx, ~grid_cell_mask, lead_idx] = 0
sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0

if args.land:
logging.info("Land masking the forecast output")
land_mask = mask_gen.get_land_mask()
Expand All @@ -176,19 +190,6 @@ def create_cf_output():
sic_mean[mask] = 0
sic_stddev[mask] = 0

if args.agcm:
logging.info("Applying active grid cell masks")

for idx, forecast_date in enumerate(dates):
for lead_idx in np.arange(0, arr.shape[3], 1):
lead_dt = forecast_date + dt.timedelta(days=int(lead_idx) + 1)
logging.debug("Active grid cell mask start {} forecast date {}".
format(forecast_date, lead_dt))

grid_cell_mask = mask_gen.get_active_cell_mask(lead_dt.month)
sic_mean[idx, ~grid_cell_mask, lead_idx] = 0
sic_stddev[idx, ~grid_cell_mask, lead_idx] = 0

lists_of_fcast_dates = [
[pd.Timestamp(date + dt.timedelta(days=int(lead_idx)))
for lead_idx in np.arange(1, arr.shape[3] + 1, 1)]
Expand All @@ -200,6 +201,7 @@ def create_cf_output():
Lambert_Azimuthal_Grid=ref_sic.Lambert_Azimuthal_Grid,
sic_mean=(["time", "yc", "xc", "leadtime"], sic_mean),
sic_stddev=(["time", "yc", "xc", "leadtime"], sic_stddev),
ensemble_members=(["time"], ens_members),
),
coords=dict(
time=[pd.Timestamp(d) for d in dates],
Expand Down Expand Up @@ -260,7 +262,7 @@ def create_cf_output():
standard_name_vocabulary="CF Standard Name Table v27",
summary="""
This is an output of sea ice concentration predictions from the
IceNet UNet run in an ensemble, with postprocessing to determine
IceNet run in an ensemble, with postprocessing to determine
the mean and standard deviation across the runs.
""",
# Use ISO 8601:2004 duration format, preferably the extended format
Expand Down Expand Up @@ -332,6 +334,12 @@ def create_cf_output():
units="1",
)

xarr.ensemble_members.attrs = dict(
long_name="number of ensemble members used to create this prediction",
short_name="ensemble_members",
# units="1",
)

# TODO: split into daily files
output_path = os.path.join(args.output_dir, "{}.nc".format(args.name))
logging.info("Saving to {}".format(output_path))
Expand Down

0 comments on commit 2606519

Please sign in to comment.