Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better tests io #72

Merged
merged 10 commits into from
Oct 14, 2024
188 changes: 137 additions & 51 deletions py4cast/io/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
import torch
import xarray as xr
from cfgrib import xarray_to_grib as xtg
from dataclasses_json import dataclass_json
Expand Down Expand Up @@ -86,70 +87,155 @@ def save_named_tensors_to_grib(

for t_idx in range(predicted_time_steps)[:1]:
for group in model_ds.keys():
target_ds = deepcopy(model_ds[group])

# if the shape of the dataset grid doesn't match grib template, fill the rest of the data with NaNs
nanmask, latlon = make_nan_mask(ds, target_ds)
(
latmin,
latmax,
longmin,
longmax,
) = latlon

target_ds["time"] = sample.date
ns_step = np.timedelta64(
int(leadtimes[t_idx] * 3600 * 1000000000),
"ns",
raw_data = pred.select_dim("timestep", t_idx, bare_tensor=False)
storable = write_storable_dataset(
pred,
ds,
model_ds[group],
group,
sample,
validtimes[t_idx],
leadtimes[t_idx],
raw_data,
grib_features,
)
ns_valid = np.timedelta64(
int(validtimes[t_idx] * 3600 * 1000000000),
"ns",
)
target_ds["step"] = ns_step
target_ds["valid_time"] = np.datetime64(sample.date) + ns_valid

# collapsing batch dimension and selecting a given timestep

data_tidx = pred.select_dim("timestep", t_idx, bare_tensor=False)
for feature_name in pred.feature_names_to_idx:

level, name, tol = grib_features.loc[
feature_name, ["level", "name", "typeOfLevel"]
]

if (
(f"{name}_{tol}" == group)
or (f"{level}_{tol}" == group)
or (tol == group)
):

data = data_tidx[feature_name].squeeze().cpu().numpy()

if nanmask is None:
data2grib = data
else:
data2grib = nanmask
data2grib[latmax : latmin + 1, longmin : longmax + 1] = data

dims = model_ds[group][name].dims
target_ds[name] = (dims, data2grib)
target_ds[name] = target_ds[name].assign_attrs(
**model_ds[group][name].attrs
)
filename = get_output_filename(saving_settings, sample, leadtimes[t_idx])
option = (
"wb"
if not os.path.exists(f"{saving_settings.directory}/{filename}")
else "ab"
)
xtg.to_grib(
target_ds,
storable,
Path(saving_settings.directory) / filename,
option,
)


def write_storable_dataset(
    pred: NamedTensor,
    ds: DatasetABC,
    template_ds: xr.Dataset,
    group: str,
    sample: Any,
    validtime: float,
    leadtime: float,
    raw_data: NamedTensor,
    grib_features: pd.DataFrame,
) -> xr.Dataset:
    """Fill a copy of the template xarray dataset with raw inference data.

    The template is deep-copied, its ``time`` / ``step`` / ``valid_time``
    entries are overwritten from ``sample.date``, ``leadtime`` and
    ``validtime`` (both multiplied by 3600 * 1e9 to obtain nanoseconds), and
    the data of the features whose grib ``name`` appears in the template are
    written into it.

    Args:
        pred (NamedTensor): complete namedtensor, containing feature names
        ds (DatasetABC): inference dataset
        template_ds (xr.Dataset): xarray dataset extracted from the template grib
        group (str): index of the template_ds in the template dict containing coherent groups
        sample (Any): the inference sample to be saved
        validtime (float): time of validity of the current sample
        leadtime (float): lead time of the current sample
        raw_data (NamedTensor): extraction from pred at current timestep
        grib_features (pd.DataFrame): complete description of feature names and definition

    Returns:
        xr.Dataset: the template dataset, filled with data from raw_data, in the correct format
    """
    receiver_ds = deepcopy(template_ds)

    # if the shape of the dataset grid doesn't match grib template,
    # the rest of the data is filled with the content of nanmask
    nanmask, latlon = make_nan_mask(ds, receiver_ds)

    receiver_ds["time"] = sample.date
    ns_step = np.timedelta64(int(leadtime * 3600 * 1000000000), "ns")
    ns_valid = np.timedelta64(int(validtime * 3600 * 1000000000), "ns")
    receiver_ds["step"] = ns_step
    receiver_ds["valid_time"] = np.datetime64(sample.date) + ns_valid

    # retrieving key metadata to be able to write all variables at once
    namelist = list(receiver_ds.keys())
    used_grib_feat = grib_features[grib_features["name"].isin(namelist)]

    # `name` is only meaningful when the template holds a single variable
    name = namelist[0]
    feature_names = used_grib_feat["feature_name"].tolist()
    tol = used_grib_feat["typeOfLevel"].drop_duplicates().tolist()[0]
    feature_idx = torch.tensor([pred.feature_names_to_idx[f] for f in feature_names])

    # `data` only holds the selected features, in the order of `feature_names`.
    # NOTE(review): the code below assumes that, when several features are
    # selected, they lie along the first axis of `data` -- confirm against
    # NamedTensor.index_select_dim.
    data = raw_data.index_select_dim("features", feature_idx).squeeze().cpu().numpy()
    dims = template_ds.dims
    n_features = len(feature_names)

    if f"{name}_{tol}" == group:
        if tol in template_ds.sizes:
            # supplementary dimension (eg isobaricInhPa): one mask per level
            grid = _embed_in_nanmask(nanmask, data, latlon, template_ds.sizes[tol])
            receiver_ds[name] = (dims, grid.astype(np.float32))
            receiver_ds[name] = receiver_ds[name].assign_attrs(
                **template_ds[name].attrs
            )
        elif n_features > 1:
            # no supplementary dimension but several variables: one mask per
            # variable, each variable being read back by its position among
            # the selected features (not by its global index in `pred`)
            grid = _embed_in_nanmask(nanmask, data, latlon, n_features)
            receiver_ds.update(
                {f: (dims, grid[i]) for i, f in enumerate(feature_names)}
            )
        else:
            # only one variable
            grid = _embed_in_nanmask(nanmask, data, latlon)
            receiver_ds[name] = (dims, grid.astype(np.float32))
            receiver_ds[name] = receiver_ds[name].assign_attrs(
                **template_ds[name].attrs
            )
    elif tol == group:
        # in this case, there might be several variables sharing the level type
        if n_features > 1:
            grid = _embed_in_nanmask(nanmask, data, latlon, n_features)
            fields = {f: (dims, grid[i]) for i, f in enumerate(feature_names)}
        else:
            # a single variable owns the whole 2D grid
            grid = _embed_in_nanmask(nanmask, data, latlon)
            fields = {f: (dims, grid) for f in feature_names}

        receiver_ds.update(fields)
        receiver_ds[name] = receiver_ds[name].assign_attrs(**template_ds[name].attrs)
    # NOTE(review): a group matching neither pattern is returned with the
    # template's own data untouched -- confirm this is intended.

    return receiver_ds


def _embed_in_nanmask(nanmask, data: np.ndarray, latlon, layers=None) -> np.ndarray:
    """Write ``data`` into the sub-window of ``nanmask`` described by ``latlon``.

    Args:
        nanmask: 2D array to fill, or None when no mask is provided
        data (np.ndarray): values to write in the window
        latlon: (latmin, latmax, longmin, longmax) indices of the window
        layers: when given, the mask is first stacked ``layers`` times along a
            new leading axis; when None, ``nanmask`` itself is filled in place

    Returns:
        np.ndarray: the filled array, or ``data`` itself when nanmask is None
    """
    if nanmask is None:
        # no mask provided: the data is stored as is
        return data

    latmin, latmax, longmin, longmax = latlon
    if layers is None:
        grid = nanmask
    else:
        grid = np.repeat(nanmask[np.newaxis], layers, axis=0)
    # the ellipsis covers the optional leading (level or variable) axis
    grid[..., latmax : latmin + 1, longmin : longmax + 1] = data
    return grid


def get_output_filename(
saving_settings: GribSavingSettings, sample: Any, leadtime: float
) -> str:
Expand Down
Loading
Loading