From 9afaf6e9a73c2cc814064814b62e58532653d27b Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 13 Nov 2024 14:14:24 +0100
Subject: [PATCH] fix bugs introduced with dimension order during stack/unstack

---
 neural_lam/datastore/base.py               | 36 +++++++++++++++-------
 neural_lam/datastore/mdp.py                |  3 +-
 neural_lam/datastore/npyfilesmeps/store.py | 13 ++++----
 tests/test_config.py                       |  2 +-
 4 files changed, 34 insertions(+), 20 deletions(-)

diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 8784ff0..f24d7b3 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -330,7 +330,7 @@ def state_feature_weights_values(self) -> List[float]:
         pass
 
     @functools.lru_cache
-    def expected_dim_order(self, category: str) -> List[str]:
+    def expected_dim_order(self, category: str = None) -> List[str]:
         """
         Return the expected dimension order for the dataarray or dataset
         returned by `get_dataarray` for the given category of data. The
@@ -342,6 +342,9 @@
         from `get_dataarray`. And also ensures that downstream uses of a
         datastore (e.g. WeatherDataset) sees the data in a common structure.
 
+        If the category is None, then it is assumed that the data only
+        represents a 1D scalar field varying with grid-index.
+
         Parameters
         ----------
         category : str
@@ -353,13 +356,24 @@
             The expected dimension order for the dataarray or dataset.
 
         """
-        dim_order = ["grid_index", f"{category}_feature"]
-        if self.is_forecast:
-            dim_order.extend(["analysis_time", "elapsed_forecast_duration"])
-        elif not self.is_forecast:
-            dim_order.append("time")
-        if self.is_ensemble:
-            dim_order.append("ensemble_member")
+        dim_order = ["grid_index"]
+
+        if category is not None:
+            dim_order.append(f"{category}_feature")
+
+        if category != "static":
+            # static data does not vary in time
+            if self.is_forecast:
+                dim_order.extend(
+                    ["analysis_time", "elapsed_forecast_duration"]
+                )
+            elif not self.is_forecast:
+                dim_order.append("time")
+
+        if self.is_ensemble and category == "state":
+            # XXX: for now we only assume ensemble data for state variables
+            dim_order.append("ensemble_member")
+
         return dim_order
@@ -498,8 +512,8 @@ def stack_grid_coords(
 
         da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
 
-        # infer what category of data by finding the dimension named in the
-        # format `{category}_feature`
+        # infer what category of data the array represents by finding the
+        # dimension named in the format `{category}_feature`
         category = None
         for dim in da_or_ds_stacked.dims:
             if dim.endswith("_feature"):
@@ -512,7 +526,7 @@
 
         dim_order = self.expected_dim_order(category=category)
 
-        return da_or_ds_stacked.transpose(dim_order)
+        return da_or_ds_stacked.transpose(*dim_order)
 
     @property
     @functools.lru_cache
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index f698ddf..df50771 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -281,7 +281,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
             )
             da_category = da_category.sel(time=slice(t_start, t_end))
 
-        return da_category.transpose(self.expected_dim_order(category=category))
+        dim_order = self.expected_dim_order(category=category)
+        return da_category.transpose(*dim_order)
 
     def get_standardization_dataarray(self, category: str) -> xr.Dataset:
         """
diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py
index 37693ea..5fd2d63 100644
--- a/neural_lam/datastore/npyfilesmeps/store.py
+++ b/neural_lam/datastore/npyfilesmeps/store.py
@@ -282,15 +282,16 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
                 features=features, split=split
             )
             das.append(da)
-            da = xr.concat(das, dim="feature").transpose(
-                "grid_index", "feature"
-            )
+            da = xr.concat(das, dim="feature")
 
         else:
             raise NotImplementedError(category)
 
         da = da.rename(dict(feature=f"{category}_feature"))
 
+        # stack the [x, y] dimensions into a `grid_index` dimension
+        da = self.stack_grid_coords(da)
+
         # check that we have the right features
         actual_features = da[f"{category}_feature"].values.tolist()
         expected_features = self.get_vars_names(category=category)
@@ -299,7 +300,8 @@
                 f"Expected features {expected_features}, got {actual_features}"
             )
 
-        da = da.transpose(self.expected_dim_order(category=category))
+        dim_order = self.expected_dim_order(category=category)
+        da = da.transpose(*dim_order)
 
         return da
 
@@ -501,9 +503,6 @@ def _get_single_timeseries_dataarray(
 
         da = xr.DataArray(arr_all, dims=dims, coords=coords)
 
-        # stack the [x, y] dimensions into a `grid_index` dimension
-        da = self.stack_grid_coords(da)
-
         return da
 
     def _get_analysis_times(self, split) -> List[np.datetime64]:
diff --git a/tests/test_config.py b/tests/test_config.py
index 4bb7c1c..1ff40bc 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -46,7 +46,7 @@ def test_config_serialization(state_weighting_config):
     training:
       state_feature_weighting:
         __config_class__: ManualStateFeatureWeighting
-        values:
+        weights:
           u100m: 1.0
           v100m: 1.0
     """
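
Note (illustration only, not part of the patch): the recurring fix above relies on
xarray's `transpose` taking dimension names as separate positional arguments, so a
list of names has to be unpacked with `*`. Below is a minimal sketch of the
difference, using hypothetical dimension names and sizes:

import numpy as np
import xarray as xr

# Hypothetical array with the kind of dimensions the datastores produce.
da = xr.DataArray(
    np.zeros((2, 3, 4)),
    dims=["time", "grid_index", "state_feature"],
)
dim_order = ["grid_index", "state_feature", "time"]

# Correct: each dimension name is passed as its own positional argument.
print(da.transpose(*dim_order).dims)  # ('grid_index', 'state_feature', 'time')

# Buggy: the list itself is treated as a single (non-existent) dimension
# name and xarray raises (the exact exception type depends on the version).
try:
    da.transpose(dim_order)
except (TypeError, ValueError) as err:
    print(f"transpose(dim_order) failed: {err}")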