Skip to content

Commit

Permalink
fix bugs introduced with dimension order during stack/unstack
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Nov 13, 2024
1 parent 1121d9f commit 9afaf6e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 20 deletions.
36 changes: 25 additions & 11 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def state_feature_weights_values(self) -> List[float]:
pass

@functools.lru_cache
def expected_dim_order(self, category: str) -> List[str]:
def expected_dim_order(self, category: str = None) -> List[str]:
"""
Return the expected dimension order for the dataarray or dataset
returned by `get_dataarray` for the given category of data. The
Expand All @@ -342,6 +342,9 @@ def expected_dim_order(self, category: str) -> List[str]:
from `get_dataarray`. And also ensures that downstream uses of a
datastore (e.g. WeatherDataset) sees the data in a common structure.
If the category is None, then the it assumed that data only represents
a 1D scalar field varying with grid-index.
Parameters
----------
category : str
Expand All @@ -353,13 +356,24 @@ def expected_dim_order(self, category: str) -> List[str]:
The expected dimension order for the dataarray or dataset.
"""
dim_order = ["grid_index", f"{category}_feature"]
if self.is_forecast:
dim_order.extend(["analysis_time", "elapsed_forecast_duration"])
elif not self.is_forecast:
dim_order.append("time")
if self.is_ensemble:
dim_order.append("ensemble_member")
dim_order = ["grid_index"]

if category is not None:
dim_order.append(f"{category}_feature")

if category != "static":
# static data does not vary in time
if self.is_forecast:
dim_order.extend(
["analysis_time", "elapsed_forecast_duration"]
)
elif not self.is_forecast:
dim_order.append("time")

if self.is_ensemble and category == "state":
# XXX: for now we only assume ensemble data for state variables
dim_order.append("ensemble_member")

return dim_order


Expand Down Expand Up @@ -498,8 +512,8 @@ def stack_grid_coords(

da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)

# infer what category of data by finding the dimension named in the
# format `{category}_feature`
# infer what category of data the array represents by finding the
# dimension named in the format `{category}_feature`
category = None
for dim in da_or_ds_stacked.dims:
if dim.endswith("_feature"):
Expand All @@ -512,7 +526,7 @@ def stack_grid_coords(

dim_order = self.expected_dim_order(category=category)

return da_or_ds_stacked.transpose(dim_order)
return da_or_ds_stacked.transpose(*dim_order)

@property
@functools.lru_cache
Expand Down
3 changes: 2 additions & 1 deletion neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
)
da_category = da_category.sel(time=slice(t_start, t_end))

return da_category.transpose(self.expected_dim_order(category=category))
dim_order = self.expected_dim_order(category=category)
return da_category.transpose(*dim_order)

def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
Expand Down
13 changes: 6 additions & 7 deletions neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,16 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
features=features, split=split
)
das.append(da)
da = xr.concat(das, dim="feature").transpose(
"grid_index", "feature"
)
da = xr.concat(das, dim="feature")

else:
raise NotImplementedError(category)

da = da.rename(dict(feature=f"{category}_feature"))

# stack the [x, y] dimensions into a `grid_index` dimension
da = self.stack_grid_coords(da)

# check that we have the right features
actual_features = da[f"{category}_feature"].values.tolist()
expected_features = self.get_vars_names(category=category)
Expand All @@ -299,7 +300,8 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
f"Expected features {expected_features}, got {actual_features}"
)

da = da.transpose(self.expected_dim_order(category=category))
dim_order = self.expected_dim_order(category=category)
da = da.transpose(*dim_order)

return da

Expand Down Expand Up @@ -501,9 +503,6 @@ def _get_single_timeseries_dataarray(

da = xr.DataArray(arr_all, dims=dims, coords=coords)

# stack the [x, y] dimensions into a `grid_index` dimension
da = self.stack_grid_coords(da)

return da

def _get_analysis_times(self, split) -> List[np.datetime64]:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_config_serialization(state_weighting_config):
training:
state_feature_weighting:
__config_class__: ManualStateFeatureWeighting
values:
weights:
u100m: 1.0
v100m: 1.0
"""
Expand Down

0 comments on commit 9afaf6e

Please sign in to comment.