add enforcement of datastores output dimension order
leifdenby committed Nov 13, 2024
1 parent d355ef5 commit 1121d9f
Showing 4 changed files with 61 additions and 29 deletions.
72 changes: 49 additions & 23 deletions neural_lam/datastore/base.py
@@ -329,6 +329,39 @@ def state_feature_weights_values(self) -> List[float]:
"""
pass

@functools.lru_cache
def expected_dim_order(self, category: str) -> List[str]:
"""
Return the expected dimension order for the dataarray or dataset
returned by `get_dataarray` for the given category of data. The
dimension order is the order of the dimensions in the dataarray or
dataset, and is used to check that the data is in the expected format.
This is necessary so that when stacking and unstacking the spatial grid
we can ensure that the dimension order is the same as what is returned
from `get_dataarray`. And also ensures that downstream uses of a
datastore (e.g. WeatherDataset) sees the data in a common structure.
Parameters
----------
category : str
The category of the dataset (state/forcing/static).
Returns
-------
List[str]
The expected dimension order for the dataarray or dataset.
"""
dim_order = ["grid_index", f"{category}_feature"]
if self.is_forecast:
dim_order.extend(["analysis_time", "elapsed_forecast_duration"])
elif not self.is_forecast:
dim_order.append("time")
if self.is_ensemble:
dim_order.append("ensemble_member")
return dim_order
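
As an aside (not part of the commit), a minimal self-contained sketch of the ordering logic above; the helper name and the flag arguments are hypothetical stand-ins for a datastore's `is_forecast`/`is_ensemble` properties:

# Hypothetical standalone sketch mirroring expected_dim_order; not part of this diff.
from typing import List

def sketch_expected_dim_order(
    category: str, is_forecast: bool, is_ensemble: bool
) -> List[str]:
    # grid and feature dims come first, then the time dim(s), then the ensemble dim
    dim_order = ["grid_index", f"{category}_feature"]
    if is_forecast:
        dim_order.extend(["analysis_time", "elapsed_forecast_duration"])
    else:
        dim_order.append("time")
    if is_ensemble:
        dim_order.append("ensemble_member")
    return dim_order

# analysis-type (non-forecast), deterministic datastore:
print(sketch_expected_dim_order("state", is_forecast=False, is_ensemble=False))
# ['grid_index', 'state_feature', 'time']

# forecast-type, ensemble datastore:
print(sketch_expected_dim_order("forcing", is_forecast=True, is_ensemble=True))
# ['grid_index', 'forcing_feature', 'analysis_time',
#  'elapsed_forecast_duration', 'ensemble_member']
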


@dataclasses.dataclass
class CartesianGridShape:
@@ -464,29 +497,22 @@ def stack_grid_coords(
return da_or_ds

da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
# find the feature dimension, which is named with the format
# `{category}_feature`

# ensure that grid_index is the first dimension, and the feature
# dimension is the second
dim_order = ["grid_index"]

potential_feature_dims = [
d for d in da_or_ds_stacked.dims if d.endswith("feature")
]
n_feature_dims = len(potential_feature_dims)
if n_feature_dims == 0:
pass
elif n_feature_dims == 1:
feature_dim = potential_feature_dims[0]
dim_order.append(feature_dim)
else:
raise ValueError(
"Expected exactly one feature dimension in the stacked data, "
f"got {potential_feature_dims}"
)

return da_or_ds_stacked.transpose(*dim_order, ...)

# infer what category of data by finding the dimension named in the
# format `{category}_feature`
category = None
for dim in da_or_ds_stacked.dims:
if dim.endswith("_feature"):
if category is not None:
raise ValueError(
"Multiple dimensions ending with '_feature' found in "
f"dataarray: {da_or_ds_stacked}. Cannot infer category."
)
category = dim.split("_")[0]

dim_order = self.expected_dim_order(category=category)

return da_or_ds_stacked.transpose(*dim_order)

@property
@functools.lru_cache
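
To make the new `stack_grid_coords` behaviour concrete, here is a hedged standalone sketch (using a toy DataArray, not a real datastore) of inferring the category from the `{category}_feature` dimension and transposing to the canonical order:

# Illustrative sketch of the category inference + transpose done in
# stack_grid_coords; the array and sizes below are made up.
import numpy as np
import xarray as xr

# toy DataArray already stacked onto a single grid_index dimension
da = xr.DataArray(
    np.zeros((3, 10, 2)),
    dims=("time", "grid_index", "state_feature"),
)

# infer the category from the single `{category}_feature` dimension
feature_dims = [d for d in da.dims if str(d).endswith("_feature")]
assert len(feature_dims) == 1, f"expected one feature dim, got {feature_dims}"
category = str(feature_dims[0]).split("_")[0]  # -> "state"

# transpose to the canonical order for a non-forecast, non-ensemble store
da = da.transpose("grid_index", f"{category}_feature", "time")
print(da.dims)  # ('grid_index', 'state_feature', 'time')
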
8 changes: 4 additions & 4 deletions neural_lam/datastore/mdp.py
@@ -266,9 +266,7 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
# set multi-index for grid-index
da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS)

if "time" not in da_category.dims:
return da_category
else:
if "time" in da_category.dims:
t_start = (
self._ds.splits.sel(split_name=split)
.sel(split_part="start")
@@ -281,7 +279,9 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
.load()
.item()
)
return da_category.sel(time=slice(t_start, t_end))
da_category = da_category.sel(time=slice(t_start, t_end))

return da_category.transpose(*self.expected_dim_order(category=category))
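
A small illustrative sketch (made-up timestamps, not MDP data) of the label-based time slicing used above to restrict a DataArray to a split's time range:

# Hypothetical sketch of slicing a DataArray to a split's start/end times.
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2020-01-01", periods=8, freq="6h")
da = xr.DataArray(np.arange(8), dims=("time",), coords={"time": times})

# t_start / t_end would come from the datastore's split definition
t_start, t_end = "2020-01-01T06:00", "2020-01-02T00:00"
da_split = da.sel(time=slice(t_start, t_end))
print(da_split.sizes["time"])  # 4 timesteps fall inside the (inclusive) slice
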

def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
8 changes: 7 additions & 1 deletion neural_lam/datastore/npyfilesmeps/store.py
@@ -299,6 +299,8 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
f"Expected features {expected_features}, got {actual_features}"
)

da = da.transpose(*self.expected_dim_order(category=category))

return da

def _get_single_timeseries_dataarray(
@@ -346,7 +348,11 @@ def _get_single_timeseries_dataarray(
None,
), "Unknown dataset split"
else:
assert split in ("train", "val", "test"), "Unknown dataset split"
assert split in (
"train",
"val",
"test",
), f"Unknown dataset split {split} for features {features}"

if member is not None and features != self.get_vars_names(
category="state"
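
For context (not part of the diff), a hypothetical sketch of why enforcing the dimension order in `get_dataarray` helps downstream consumers rely on positional shapes; the array below is illustrative, not real MEPS data:

# Once dims follow the canonical order, positional shapes are predictable.
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((4, 100, 5)),
    dims=("time", "grid_index", "state_feature"),
)
da = da.transpose("grid_index", "state_feature", "time")

# a consumer (e.g. a torch Dataset wrapper) can now rely on axis positions
n_grid, n_features, n_timesteps = da.shape
assert (n_grid, n_features, n_timesteps) == (100, 5, 4)
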
2 changes: 1 addition & 1 deletion tests/test_datastores.py
@@ -319,7 +319,7 @@ def test_stacking_grid_coords(datastore_name, category):
if not isinstance(datastore, BaseRegularGridDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")

da_static = datastore.get_dataarray(category=category, split=None)
da_static = datastore.get_dataarray(category=category, split="train")

da_static_unstacked = datastore.unstack_grid_coords(da_static).load()
da_static_test = datastore.stack_grid_coords(da_static_unstacked)
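
A hedged standalone sketch (toy coordinates, not a real datastore) of the unstack/stack round trip that this test exercises:

# Illustrative round-trip check; coordinate names (x, y) and sizes are made up.
import numpy as np
import xarray as xr

da_field = xr.DataArray(
    np.random.rand(4, 3),
    dims=("x", "y"),
    coords={"x": [0, 1, 2, 3], "y": [10, 20, 30]},
)

da_grid = da_field.stack(grid_index=("x", "y"))      # as stack_grid_coords does
da_unstacked = da_grid.unstack("grid_index")         # as unstack_grid_coords does
da_restacked = da_unstacked.stack(grid_index=("x", "y"))

xr.testing.assert_equal(da_grid, da_restacked)
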
