From 743d7a1fdee785da8a890617967f3c16cb6d87b4 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Thu, 12 Sep 2024 11:21:07 +0200
Subject: [PATCH] fix coord issues and add datastore example plotting cli

---
 neural_lam/create_graph.py             |   9 +-
 neural_lam/datastore/base.py           |  42 ++---
 neural_lam/datastore/mllam.py          |  50 +++++-
 neural_lam/datastore/npyfiles/store.py |  48 +++---
 neural_lam/datastore/plot_example.py   | 150 ++++++++++++++++++
 neural_lam/weather_dataset.py          | 207 +++++++++++++++++++++----
 tests/test_datasets.py                 |  78 +++++++++-
 tests/test_datastores.py               |  23 ++-
 8 files changed, 515 insertions(+), 92 deletions(-)
 create mode 100644 neural_lam/datastore/plot_example.py

diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py
index 4ce0811b..0b267f67 100644
--- a/neural_lam/create_graph.py
+++ b/neural_lam/create_graph.py
@@ -13,9 +13,8 @@
 from torch_geometric.utils.convert import from_networkx
 
 # Local
+from .datastore import DATASTORES
 from .datastore.base import BaseCartesianDatastore
-from .datastore.mllam import MLLAMDatastore
-from .datastore.npyfiles import NpyFilesDatastore
 
 
 def plot_graph(graph, title=None):
@@ -532,12 +531,6 @@ def create_graph(
     save_edges(pyg_m2g, "m2g", graph_dir_path)
 
 
-DATASTORES = dict(
-    mllam=MLLAMDatastore,
-    npyfiles=NpyFilesDatastore,
-)
-
-
 def create_graph_from_datastore(
     datastore: BaseCartesianDatastore,
     output_root_path: str,
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 3a943c18..8e6d6e8d 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -9,6 +9,7 @@
 import cartopy.crs as ccrs
 import numpy as np
 import xarray as xr
+from pandas.core.indexes.multi import MultiIndex
 
 
 class BaseDatastore(abc.ABC):
@@ -228,21 +229,13 @@ class CartesianGridShape:
 
 class BaseCartesianDatastore(BaseDatastore):
-    """Base class for weather
-    data stored on a Cartesian
-    grid. In addition to the
-    methods and attributes
-    required for weather data
-    in general (see
-    `BaseDatastore`) for
-    Cartesian gridded source
-    data each `grid_index`
-    coordinate value is assume
-    to have an associated `x`
-    and `y`-value so that the
-    processed data-arrays can
-    be reshaped back into into
-    2D xy-gridded arrays.
+    """
+    Base class for weather data stored on a Cartesian grid. In addition to
+    the methods and attributes required for weather data in general (see
+    `BaseDatastore`), for Cartesian gridded source data each `grid_index`
+    coordinate value is assumed to have an associated `x` and `y`-value so
+    that the processed data-arrays can be reshaped back into 2D xy-gridded
+    arrays.
 
     In addition the following attributes and methods are required:
     - `coords_projection` (property): Projection object for the coordinates.
@@ -253,7 +246,7 @@ class BaseCartesianDatastore(BaseDatastore):
 
     """
 
-    CARTESIAN_COORDS = ["y", "x"]
+    CARTESIAN_COORDS = ["x", "y"]
 
     @property
    @abc.abstractmethod
@@ -347,9 +340,20 @@ def unstack_grid_coords(
             The dataarray or dataset with the grid coordinates unstacked.
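+
+        Examples
+        --------
+        A minimal sketch (assuming an initialised datastore instance
+        `datastore` and a dataarray `da` with a stacked `grid_index`
+        dimension; both names are illustrative only):
+
+        >>> da_gridded = datastore.unstack_grid_coords(da)
+        >>> ("x" in da_gridded.dims) and ("y" in da_gridded.dims)
+        True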
""" - return da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS).unstack( - "grid_index" - ) + # check whether `grid_index` is a multi-index + if not isinstance(da_or_ds.indexes.get("grid_index"), MultiIndex): + da_or_ds = da_or_ds.set_index(grid_index=self.CARTESIAN_COORDS) + + da_or_ds_unstacked = da_or_ds.unstack("grid_index") + + # ensure that the x, y dimensions are in the correct order + dims = da_or_ds_unstacked.dims + xy_dim_order = [d for d in dims if d in self.CARTESIAN_COORDS] + + if xy_dim_order != self.CARTESIAN_COORDS: + da_or_ds_unstacked = da_or_ds_unstacked.transpose("y", "x") + + return da_or_ds_unstacked def stack_grid_coords( self, da_or_ds: Union[xr.DataArray, xr.Dataset] diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py index 15886b9e..b0867d02 100644 --- a/neural_lam/datastore/mllam.py +++ b/neural_lam/datastore/mllam.py @@ -70,6 +70,19 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): if len(self.get_vars_names(category)) > 0: print(f"{category}: {' '.join(self.get_vars_names(category))}") + # find out the dimension order for the stacking to grid-index + dim_order = None + for input_dataset in self._config.inputs.values(): + dim_order_ = input_dataset.dim_mapping["grid_index"].dims + if dim_order is None: + dim_order = dim_order_ + else: + assert ( + dim_order == dim_order_ + ), "all inputs must have the same dimension order" + + self.CARTESIAN_COORDS = dim_order + @property def root_path(self) -> Path: """The root path of the dataset. @@ -202,6 +215,14 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: da_category = self._ds[category] + # set units on x y coordinates if missing + for coord in ["x", "y"]: + if "units" not in da_category[coord].attrs: + da_category[coord].attrs["units"] = "m" + + # set multi-index for grid-index + da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS) + if "time" not in da_category.dims: return da_category else: @@ -294,10 +315,26 @@ def coords_projection(self) -> ccrs.Projection: The projection of the coordinates. 
""" - # TODO: danra doesn't contain projection information yet, but the next - # version will for now we hardcode the projection - # XXX: this is wrong - return ccrs.PlateCarree() + # XXX: this should move to config + kwargs = { + "LoVInDegrees": 25.0, + "LaDInDegrees": 56.7, + "Latin1InDegrees": 56.7, + "Latin2InDegrees": 56.7, + } + + lon_0 = kwargs["LoVInDegrees"] # Latitude of first standard parallel + lat_0 = kwargs["LaDInDegrees"] # Latitude of second standard parallel + lat_1 = kwargs["Latin1InDegrees"] # Origin latitude + lat_2 = kwargs["Latin2InDegrees"] # Origin longitude + + crs = ccrs.LambertConformal( + central_longitude=lon_0, + central_latitude=lat_0, + standard_parallels=(lat_1, lat_2), + ) + + return crs @property def grid_shape_state(self): @@ -346,10 +383,11 @@ def get_xy(self, category: str, stacked: bool) -> ndarray: da_xy = xr.concat([da_x, da_y], dim="grid_coord") if stacked: - da_xy = da_xy.stack(grid_index=("y", "x")).transpose( + da_xy = da_xy.stack(grid_index=self.CARTESIAN_COORDS).transpose( "grid_coord", "grid_index" ) else: - da_xy = da_xy.transpose("grid_coord", "y", "x") + dims = ["grid_coord", "y", "x"] + da_xy = da_xy.transpose(*dims) return da_xy.values diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py index 6b2e72f4..9f4d90e4 100644 --- a/neural_lam/datastore/npyfiles/store.py +++ b/neural_lam/datastore/npyfiles/store.py @@ -347,8 +347,7 @@ def _get_single_timeseries_dataarray( "Member can only be specified for the 'state' category" ) - # XXX: we here assume that the grid shape is the same for all categories - grid_shape = self.grid_shape_state + concat_axis = 0 file_params = {} add_feature_dim = False @@ -387,7 +386,8 @@ def _get_single_timeseries_dataarray( fp_samples = self.root_path / "static" elif features == ["x", "y"]: filename_format = "nwp_xy.npy" - file_dims = ["y", "x", "feature"] + # NB: for x, y the feature dimension is the first one + file_dims = ["feature", "y", "x"] features_vary_with_analysis_time = False # XXX: x, y are the same for all splits, and so saved in static/ fp_samples = self.root_path / "static" @@ -403,6 +403,12 @@ def _get_single_timeseries_dataarray( coords = {} arr_shape = [] + + xs, ys = self.get_xy(category="state", stacked=False) + assert np.all(xs[0, :] == xs[-1, :]) + assert np.all(ys[:, 0] == ys[:, -1]) + x = xs[0, :] + y = ys[:, 0] for d in dims: if d == "elapsed_forecast_duration": coord_values = ( @@ -413,9 +419,9 @@ def _get_single_timeseries_dataarray( elif d == "analysis_time": coord_values = self._get_analysis_times(split=split) elif d == "y": - coord_values = np.arange(grid_shape.y) + coord_values = y elif d == "x": - coord_values = np.arange(grid_shape.x) + coord_values = x elif d == "feature": coord_values = features else: @@ -450,7 +456,7 @@ def _get_single_timeseries_dataarray( ] if features_vary_with_analysis_time: - arr_all = dask.array.stack(arrays, axis=0) + arr_all = dask.array.stack(arrays, axis=concat_axis) else: arr_all = arrays[0] @@ -568,17 +574,17 @@ def get_xy(self, category: str, stacked: bool) -> np.ndarray: Returns ------- np.ndarray - The x, y coordinates of the dataset, returned differently based on - the value of `stacked`: + The x, y coordinates of the dataset (with x first then y second), + returned differently based on the value of `stacked`: - `stacked==True`: shape `(2, n_grid_points)` where n_grid_points=N_x*N_y. 
             - `stacked==False`: shape `(2, N_y, N_x)`
 
         """
-        # the array on disk has shape [2, N_x, N_y], but we want to return it
-        # as [2, N_y, N_x] so we swap the axes
-        arr = np.load(self.root_path / "static" / "nwp_xy.npy").swapaxes(1, 2)
+        # the array on disk has shape [2, N_y, N_x], with the first dimension
+        # being [x, y]
+        arr = np.load(self.root_path / "static" / "nwp_xy.npy")
         assert arr.shape[0] == 2, "Expected 2D array"
         grid_shape = self.grid_shape_state
@@ -611,7 +617,7 @@ def grid_shape_state(self) -> CartesianGridShape:
         The shape of the cartesian grid for the state variables.
 
         """
-        nx, ny = self.config.grid_shape_state
+        ny, nx = self.config.grid_shape_state
         return CartesianGridShape(x=nx, y=ny)
 
     @property
@@ -626,10 +632,10 @@ def boundary_mask(self) -> xr.DataArray:
 
         """
         xs, ys = self.get_xy(category="state", stacked=False)
-        assert np.all(xs[:, 0] == xs[:, -1])
-        assert np.all(ys[0, :] == ys[-1, :])
-        x = xs[:, 0]
-        y = ys[0, :]
+        assert np.all(xs[0, :] == xs[-1, :])
+        assert np.all(ys[:, 0] == ys[:, -1])
+        x = xs[0, :]
+        y = ys[:, 0]
         values = np.load(self.root_path / "static" / "border_mask.npy")
         da_mask = xr.DataArray(
             values, dims=["y", "x"], coords=dict(x=x, y=y), name="boundary_mask"
@@ -677,11 +683,11 @@ def load_pickled_tensor(fn):
             std_values = np.array([flux_std, 1.0, 1.0, 1.0, 1.0, 1.0])
 
         elif category == "static":
-            ds_static = self.get_dataarray(category="static", split="train")
-            ds_static_mean = ds_static.mean(dim=["grid_index"])
-            ds_static_std = ds_static.std(dim=["grid_index"])
-            mean_values = ds_static_mean["static_feature"].values
-            std_values = ds_static_std["static_feature"].values
+            da_static = self.get_dataarray(category="static", split="train")
+            da_static_mean = da_static.mean(dim=["grid_index"]).compute()
+            da_static_std = da_static.std(dim=["grid_index"]).compute()
+            mean_values = da_static_mean.values
+            std_values = da_static_std.values
         else:
             raise NotImplementedError(f"Category {category} not supported")
diff --git a/neural_lam/datastore/plot_example.py b/neural_lam/datastore/plot_example.py
new file mode 100644
index 00000000..53bc6d5e
--- /dev/null
+++ b/neural_lam/datastore/plot_example.py
@@ -0,0 +1,150 @@
+# Third-party
+import matplotlib.pyplot as plt
+
+
+def plot_example_from_datastore(
+    category, datastore, col_dim, split="train", standardize=True, selection={}
+):
+    """
+    Create a plot of the data from the datastore.
+
+    Parameters
+    ----------
+    category : str
+        Category of data to plot, one of "state", "forcing", or "static".
+    datastore : Datastore
+        Datastore to retrieve data from.
+    col_dim : str
+        Dimension to use for plot faceting into columns. This can be a
+        template string that can be formatted with the category name.
+    split : str, optional
+        Split of data to plot, by default "train".
+    standardize : bool, optional
+        Whether to standardize the data before plotting, by default True.
+    selection : dict, optional
+        Selections to apply to the dataarray, for example
+        `time="1990-09-03T0:00"` would select this single timestep, by
+        default {}.
+
+    Returns
+    -------
+    Figure
+        Matplotlib figure object.
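+
+    Examples
+    --------
+    A hedged sketch (assuming an initialised `datastore`; the
+    `{category}_feature` column dimension follows the feature-naming
+    convention used elsewhere in this patch):
+
+    >>> fig = plot_example_from_datastore(
+    ...     category="state",
+    ...     datastore=datastore,
+    ...     col_dim="{category}_feature",
+    ...     selection={"time": "1990-09-03T0:00"},
+    ... )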
+ """ + da = datastore.get_dataarray(category=category, split=split) + if standardize: + da_stats = datastore.get_standardization_dataarray(category=category) + da = (da - da_stats[f"{category}_mean"]) / da_stats[f"{category}_std"] + da = datastore.unstack_grid_coords(da) + + if len(selection) > 0: + da = da.sel(**selection) + + col = col_dim.format(category=category) + + # check that the column dimension exists and that the resulting shape is 2D + if col not in da.dims: + raise ValueError(f"Column dimension {col} not found in dataarray.") + if not len(da.isel({col: 0}).squeeze().shape) == 2: + raise ValueError( + f"Column dimension {col} and selection {selection} does not " + "result in a 2D dataarray. Please adjust the column dimension " + "and/or selection." + ) + + crs = datastore.coords_projection + g = da.plot( + x="x", + y="y", + col=col, + col_wrap=min(4, int(da[col].count())), + subplot_kws={"projection": crs}, + transform=crs, + size=4, + ) + for ax in g.axes.flat: + ax.coastlines() + ax.gridlines(draw_labels=["left", "bottom"]) + + return g.fig + + +if __name__ == "__main__": + # Standard library + import argparse + + # Local + from . import init_datastore + + def _parse_dict(arg_str): + key, value = arg_str.split("=") + for op in [int, float]: + try: + value = op(value) + break + except ValueError: + pass + return key, value + + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("datastore_kind", help="Kind of datastore to use.") + parser.add_argument( + "config_path", help="Path to the datastore configuration file." + ) + parser.add_argument( + "--category", + default="state", + help="Category of data to plot", + choices=["state", "forcing", "static"], + ) + parser.add_argument( + "--split", default="train", help="Split of data to plot" + ) + parser.add_argument( + "--col-dim", + default="{category}_feature", + help="Dimension to use for plot facetting into columns", + ) + parser.add_argument( + "--disable-standardize", + dest="standardize", + action="store_false", + help="Disable standardization of data", + ) + # add the ability to create dictionary of kwargs + parser.add_argument( + "--selection", + nargs="+", + default=[], + type=_parse_dict, + help=( + "Selections to apply to the dataarray, for example " + '`time="1990-09-03T0:00" would select this single timestep', + ), + ) + args = parser.parse_args() + + selection = dict(args.selection) + + # check that column dimension is not in the selection + if args.col_dim.format(category=args.category) in selection: + raise ValueError( + f"Column dimension {args.col_dim.format(category=args.category)} " + f"cannot be in the selection ({selection}). Please adjust the " + "column dimension and/or selection." 
+ ) + + datastore = init_datastore( + datastore_kind=args.datastore_kind, config_path=args.config_path + ) + plot_example_from_datastore( + args.category, + datastore, + split=args.split, + col_dim=args.col_dim, + standardize=args.standardize, + selection=selection, + ) + plt.show() diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 72fa5d54..ed330e29 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,7 +1,10 @@ # Standard library +import datetime import warnings +from typing import Union # Third-party +import numpy as np import pytorch_lightning as pl import torch import xarray as xr @@ -141,32 +144,26 @@ def _sample_time(self, da, idx, n_steps: int, n_timesteps_offset: int = 0): ) return da - def __getitem__(self, idx): + def _build_item_dataarrays(self, idx): """ - Return a single training sample, which consists of the initial states, - target states, forcing and batch times. - - The implementation currently uses xarray.DataArray objects for the - standardization (scaling to mean 0.0 and standard deviation of 1.0) so - that we can make us of xarray's broadcasting capabilities. This makes - it possible to standardization with both global means, but also for - example where a grid-point mean has been computed. This code will have - to be replace if standardization is to be done on the GPU to handle - different shapes of the standardization. + Create the dataarrays for the initial states, target states and forcing + data for the sample at index `idx`. Parameters ---------- idx : int - The index of the sample to return, this will refer to the time of - the initial state. + The index of the sample to create the dataarrays for. Returns ------- - init_states : TrainingSample - A training sample object containing the initial states, target - states, forcing and batch times. The batch times are the times of - the target steps. - + da_init_states : xr.DataArray + The dataarray for the initial states. + da_target_states : xr.DataArray + The dataarray for the target states. + da_forcing_windowed : xr.DataArray + The dataarray for the forcing data, windowed for the sample. + da_target_times : xr.DataArray + The dataarray for the target times. """ # handling ensemble data if self.datastore.is_ensemble: @@ -230,7 +227,7 @@ def __getitem__(self, idx): da_init_states = da_state.isel(time=slice(None, 2)) da_target_states = da_state.isel(time=slice(2, None)) - batch_times = da_target_states.time.values.astype(float) + da_target_times = da_target_states.time if self.standardize: da_init_states = ( @@ -251,29 +248,81 @@ def __getitem__(self, idx): da_forcing_windowed = da_forcing_windowed.stack( forcing_feature_windowed=("forcing_feature", "window_sample") ) + else: + # create an empty forcing tensor with the right shape + da_forcing_windowed = xr.DataArray( + data=np.empty( + (self.ar_steps, da_state.grid_index.size, 0), + ), + dims=("time", "grid_index", "forcing_feature"), + coords={ + "time": da_target_times, + "grid_index": da_state.grid_index, + "forcing_feature": [], + }, + ) + + return ( + da_init_states, + da_target_states, + da_forcing_windowed, + da_target_times, + ) + + def __getitem__(self, idx): + """ + Return a single training sample, which consists of the initial states, + target states, forcing and batch times. + + The implementation currently uses xarray.DataArray objects for the + standardization (scaling to mean 0.0 and standard deviation of 1.0) so + that we can make us of xarray's broadcasting capabilities. 
+        it possible to standardize not only with global means, but also, for
+        example, with a mean computed per grid point. This code will have to
+        be replaced if standardization is to be done on the GPU, to handle
+        different shapes of the standardization statistics.
+
+        Parameters
+        ----------
+        idx : int
+            The index of the sample to return, this will refer to the time of
+            the initial state.
+
+        Returns
+        -------
+        init_states : TrainingSample
+            A training sample object containing the initial states, target
+            states, forcing and batch times. The batch times are the times of
+            the target steps.
+
+        """
+        (
+            da_init_states,
+            da_target_states,
+            da_forcing_windowed,
+            da_target_times,
+        ) = self._build_item_dataarrays(idx=idx)
+
+        tensor_dtype = torch.float32
 
-        init_states = torch.tensor(da_init_states.values, dtype=torch.float32)
+        init_states = torch.tensor(da_init_states.values, dtype=tensor_dtype)
         target_states = torch.tensor(
-            da_target_states.values, dtype=torch.float32
+            da_target_states.values, dtype=tensor_dtype
         )
 
-        if self.da_forcing is None:
-            # create an empty forcing tensor
-            forcing = torch.empty(
-                (self.ar_steps, da_state.grid_index.size, 0),
-                dtype=torch.float32,
-            )
-        else:
-            forcing = torch.tensor(
-                da_forcing_windowed.values, dtype=torch.float32
-            )
+        target_times = torch.tensor(
+            da_target_times.astype("datetime64[ns]").astype("int64").values,
+            dtype=torch.int64,
+        )
+
+        forcing = torch.tensor(da_forcing_windowed.values, dtype=tensor_dtype)
 
         # init_states: (2, N_grid, d_features)
         # target_states: (ar_steps, N_grid, d_features)
         # forcing: (ar_steps, N_grid, d_windowed_forcing)
-        # batch_times: (ar_steps,)
+        # target_times: (ar_steps,)
 
-        return init_states, target_states, forcing, batch_times
+        return init_states, target_states, forcing, target_times
 
     def __iter__(self):
         """
@@ -286,6 +335,98 @@ def __iter__(self):
         for i in range(len(self)):
             yield self[i]
 
+    def create_dataarray_from_tensor(
+        self,
+        tensor: torch.Tensor,
+        time: Union[datetime.datetime, list[datetime.datetime]],
+        category: str,
+    ):
+        """
+        Construct an xarray.DataArray from a `torch.Tensor` with coordinates
+        for `grid_index`, `time` and `{category}_feature` matching the shape
+        and number of times provided and add the x/y coordinates from the
+        datastore.
+
+        The number of times provided is expected to match the shape of the
+        tensor. For a 2D tensor, the dimensions are assumed to be (grid_index,
+        {category}_feature) and only a single time should be provided. For a 3D
+        tensor, the dimensions are assumed to be (time, grid_index,
+        {category}_feature) and a list of times should be provided.
+
+        Parameters
+        ----------
+        tensor : torch.Tensor
+            The tensor to construct the DataArray from, this is assumed to
+            have the same dimension ordering as returned by the __getitem__
+            method (i.e. time, grid_index, {category}_feature).
+        time : datetime.datetime or list[datetime.datetime]
+            The time or times of the tensor.
+        category : str
+            The category of the tensor, either "state", "forcing" or "static".
+
+        Returns
+        -------
+        da : xr.DataArray
+            The constructed DataArray.
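+
+        Examples
+        --------
+        A hedged sketch (assuming `dataset` is a `WeatherDataset`
+        instance; converting the returned target times back to datetimes
+        mirrors the int64 nanosecond encoding used in `__getitem__`):
+
+        >>> _, target_states, _, target_times = dataset[0]
+        >>> times = list(np.array(target_times, dtype="datetime64[ns]"))
+        >>> da_target = dataset.create_dataarray_from_tensor(
+        ...     tensor=target_states, time=times, category="state"
+        ... )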
+ """ + + def _is_listlike(obj): + # match list, tuple, numpy array + return hasattr(obj, "__iter__") and not isinstance(obj, str) + + add_time_as_dim = False + if len(tensor.shape) == 2: + dims = ["grid_index", f"{category}_feature"] + if _is_listlike(time): + raise ValueError( + "Expected a single time for a 2D tensor with assumed " + "dimensions (grid_index, {category}_feature), but got " + f"{len(time)} times" + ) + elif len(tensor.shape) == 3: + add_time_as_dim = True + dims = ["time", "grid_index", f"{category}_feature"] + if not _is_listlike(time): + raise ValueError( + "Expected a list of times for a 3D tensor with assumed " + "dimensions (time, grid_index, {category}_feature), but " + "got a single time" + ) + else: + raise ValueError( + "Expected tensor to have 2 or 3 dimensions, but got " + f"{len(tensor.shape)}" + ) + + da_datastore_state = getattr(self, f"da_{category}") + da_grid_index = da_datastore_state.grid_index + da_state_feature = da_datastore_state.state_feature + + coords = { + f"{category}_feature": da_state_feature, + "grid_index": da_grid_index, + } + if add_time_as_dim: + coords["time"] = time + + da = xr.DataArray( + tensor.numpy(), + dims=dims, + coords=coords, + ) + + for grid_coord in ["x", "y"]: + if ( + grid_coord in da_datastore_state.coords + and grid_coord not in da.coords + ): + da.coords[grid_coord] = da_datastore_state[grid_coord] + + if not add_time_as_dim: + da.coords["time"] = time + + return da + class WeatherDataModule(pl.LightningDataModule): """DataModule for weather data.""" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index ad03a880..efe2b1c4 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,6 +2,7 @@ from pathlib import Path # Third-party +import numpy as np import pytest import torch from conftest import init_datastore_example @@ -15,7 +16,7 @@ @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_dataset_item(datastore_name): +def test_dataset_item_shapes(datastore_name): """Check that the `datastore.get_dataarray` method is implemented. 
 Validate the shapes of the tensors match between the different
 
@@ -42,7 +43,7 @@
 
     # unpack the item, this is the current return signature for
     # WeatherDataset.__getitem__
-    init_states, target_states, forcing, batch_times = item
+    init_states, target_states, forcing, target_times = item
 
     # initial states
     assert init_states.ndim == 3
@@ -66,8 +67,8 @@
     )
 
     # batch times
-    assert batch_times.ndim == 1
-    assert batch_times.shape[0] == N_pred_steps
+    assert target_times.ndim == 1
+    assert target_times.shape[0] == N_pred_steps
 
     # try to get the last item of the dataset to ensure slicing and stacking
     # operations are working as expected and are consistent with the dataset
@@ -75,6 +76,75 @@
     dataset[len(dataset) - 1]
 
 
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_dataset_item_create_dataarray_from_tensor(datastore_name):
+    datastore = init_datastore_example(datastore_name)
+
+    N_pred_steps = 4
+    forcing_window_size = 3
+    dataset = WeatherDataset(
+        datastore=datastore,
+        split="train",
+        ar_steps=N_pred_steps,
+        forcing_window_size=forcing_window_size,
+    )
+
+    idx = 0
+
+    # unpack the item, this is the current return signature for
+    # WeatherDataset.__getitem__
+    _, target_states, _, target_times_arr = dataset[idx]
+    _, da_target_true, _, da_target_times_true = dataset._build_item_dataarrays(
+        idx=idx
+    )
+
+    target_times = np.array(target_times_arr, dtype="datetime64[ns]")
+    np.testing.assert_equal(target_times, da_target_times_true.values)
+
+    da_target = dataset.create_dataarray_from_tensor(
+        tensor=target_states, category="state", time=target_times
+    )
+
+    # conversion to torch.float32 may lead to loss of precision
+    np.testing.assert_allclose(
+        da_target.values, da_target_true.values, rtol=1e-6
+    )
+    assert da_target.dims == da_target_true.dims
+    for dim in da_target.dims:
+        np.testing.assert_equal(
+            da_target[dim].values, da_target_true[dim].values
+        )
+
+    # test unstacking the grid coordinates
+    da_target_unstacked = datastore.unstack_grid_coords(da_target)
+    assert all(
+        coord_name in da_target_unstacked.coords for coord_name in ["x", "y"]
+    )
+
+    # check construction of a single time
+    da_target_single = dataset.create_dataarray_from_tensor(
+        tensor=target_states[0], category="state", time=target_times[0]
+    )
+
+    # check that the content is the same
+    # conversion to torch.float32 may lead to loss of precision
+    np.testing.assert_allclose(
+        da_target_single.values, da_target_true[0].values, rtol=1e-6
+    )
+    assert da_target_single.dims == da_target_true[0].dims
+    for dim in da_target_single.dims:
+        np.testing.assert_equal(
+            da_target_single[dim].values, da_target_true[0][dim].values
+        )
+
+    # test unstacking the grid coordinates
+    da_target_single_unstacked = datastore.unstack_grid_coords(da_target_single)
+    assert all(
+        coord_name in da_target_single_unstacked.coords
+        for coord_name in ["x", "y"]
+    )
+
+
 @pytest.mark.parametrize("split", ["train", "val", "test"])
 @pytest.mark.parametrize("datastore_name", DATASTORES.keys())
 def test_single_batch(datastore_name, split):
diff --git a/tests/test_datastores.py b/tests/test_datastores.py
index dea99e96..1388e4e0 100644
--- a/tests/test_datastores.py
+++ b/tests/test_datastores.py
@@ -27,6 +27,7 @@
 - [x] `get_xy` (method): Return the x, y coordinates of the dataset.
 - [x] `coords_projection` (property): Projection object for the
       coordinates.
 - [x] `grid_shape_state` (property): Shape of the grid for the state variables.
+- [x] `stack_grid_coords` (method): Stack the grid coordinates of the dataset.
 
 """
@@ -124,7 +125,7 @@ def test_get_normalization_dataarray(datastore_name):
     datastore = init_datastore_example(datastore_name)
 
     for category in ["state", "forcing", "static"]:
-        ds_stats = datastore.get_normalization_dataarray(category=category)
+        ds_stats = datastore.get_standardization_dataarray(category=category)
 
         # check that the returned object is an xarray DataArray
         # and that it has the correct variables
@@ -295,3 +296,23 @@ def get_grid_shape_state(datastore_name):
     assert len(grid_shape) == 2
     assert all(isinstance(e, int) for e in grid_shape)
     assert all(e > 0 for e in grid_shape)
+
+
+@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
+def test_stacking_grid_coords(datastore_name):
+    """Check that the `datastore.stack_grid_coords` method is implemented."""
+    datastore = init_datastore_example(datastore_name)
+
+    if not isinstance(datastore, BaseCartesianDatastore):
+        pytest.skip("Datastore does not implement `BaseCartesianDatastore`")
+
+    da_static = datastore.get_dataarray("static", split=None)
+
+    da_static_unstacked = datastore.unstack_grid_coords(da_static).load()
+    da_static_test = datastore.stack_grid_coords(da_static_unstacked)
+
+    # XXX: for the moment unstacking doesn't guarantee the order of the
+    # dimensions; maybe we should enforce this?
+    da_static_test = da_static_test.transpose(*da_static.dims)
+
+    xr.testing.assert_equal(da_static, da_static_test)
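--
Usage sketch (not part of the diff): the plotting CLI added above can be run
as a module. The datastore kind and config path below are illustrative
placeholders, not values taken from this patch:

    python -m neural_lam.datastore.plot_example mllam <path/to/config.yaml> \
        --category state --split train --selection time=1990-09-03T0:00

`--selection` key=value pairs are parsed by `_parse_dict`, which tries int
and float before falling back to a string, so the timestamp above reaches
`da.sel(...)` as a string.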