diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 418b3b3..8784ff0 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -329,6 +329,39 @@ def state_feature_weights_values(self) -> List[float]: """ pass + @functools.lru_cache + def expected_dim_order(self, category: str) -> List[str]: + """ + Return the expected dimension order for the dataarray or dataset + returned by `get_dataarray` for the given category of data. The + dimension order is the order of the dimensions in the dataarray or + dataset, and is used to check that the data is in the expected format. + + This is necessary so that when stacking and unstacking the spatial grid + we can ensure that the dimension order is the same as what is returned + from `get_dataarray`. And also ensures that downstream uses of a + datastore (e.g. WeatherDataset) sees the data in a common structure. + + Parameters + ---------- + category : str + The category of the dataset (state/forcing/static). + + Returns + ------- + List[str] + The expected dimension order for the dataarray or dataset. + + """ + dim_order = ["grid_index", f"{category}_feature"] + if self.is_forecast: + dim_order.extend(["analysis_time", "elapsed_forecast_duration"]) + elif not self.is_forecast: + dim_order.append("time") + if self.is_ensemble: + dim_order.append("ensemble_member") + return dim_order + @dataclasses.dataclass class CartesianGridShape: @@ -464,29 +497,22 @@ def stack_grid_coords( return da_or_ds da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) - # find the feature dimension, which has named with the format - # `{category}_feature` - - # ensure that grid_index is the first dimension, and the feature - # dimension is the second - dim_order = ["grid_index"] - - potential_feature_dims = [ - d for d in da_or_ds_stacked.dims if d.endswith("feature") - ] - n_feature_dims = len(potential_feature_dims) - if n_feature_dims == 0: - pass - elif n_feature_dims == 1: - feature_dim = potential_feature_dims[0] - dim_order.append(feature_dim) - else: - raise ValueError( - "Expected exactly one feature dimension in the stacked data, " - f"got {potential_feature_dims}" - ) - - return da_or_ds_stacked.transpose(*dim_order, ...) + + # infer what category of data by finding the dimension named in the + # format `{category}_feature` + category = None + for dim in da_or_ds_stacked.dims: + if dim.endswith("_feature"): + if category is not None: + raise ValueError( + "Multiple dimensions ending with '_feature' found in " + f"dataarray: {da_or_ds_stacked}. Cannot infer category." + ) + category = dim.split("_")[0] + + dim_order = self.expected_dim_order(category=category) + + return da_or_ds_stacked.transpose(dim_order) @property @functools.lru_cache diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 2e438af..f698ddf 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -266,9 +266,7 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: # set multi-index for grid-index da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS) - if "time" not in da_category.dims: - return da_category - else: + if "time" in da_category.dims: t_start = ( self._ds.splits.sel(split_name=split) .sel(split_part="start") @@ -281,7 +279,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: .load() .item() ) - return da_category.sel(time=slice(t_start, t_end)) + da_category = da_category.sel(time=slice(t_start, t_end)) + + return da_category.transpose(self.expected_dim_order(category=category)) def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index ffa70dc..37693ea 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -299,6 +299,8 @@ def get_dataarray(self, category: str, split: str) -> DataArray: f"Expected features {expected_features}, got {actual_features}" ) + da = da.transpose(self.expected_dim_order(category=category)) + return da def _get_single_timeseries_dataarray( @@ -346,7 +348,11 @@ def _get_single_timeseries_dataarray( None, ), "Unknown dataset split" else: - assert split in ("train", "val", "test"), "Unknown dataset split" + assert split in ( + "train", + "val", + "test", + ), f"Unknown dataset split {split} for features {features}" if member is not None and features != self.get_vars_names( category="state" diff --git a/tests/test_datastores.py b/tests/test_datastores.py index c28418c..096efcb 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -319,7 +319,7 @@ def test_stacking_grid_coords(datastore_name, category): if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip("Datastore does not implement `BaseCartesianDatastore`") - da_static = datastore.get_dataarray(category=category, split=None) + da_static = datastore.get_dataarray(category=category, split="train") da_static_unstacked = datastore.unstack_grid_coords(da_static).load() da_static_test = datastore.stack_grid_coords(da_static_unstacked)