Skip to content

Commit

Permalink
Implement standardization of static features
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Dec 7, 2024
1 parent 71cfdf9 commit 0805ff6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 11 deletions.
31 changes: 30 additions & 1 deletion neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,36 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
pass

def _standardize_datarray(
self, da: xr.DataArray, category: str
) -> xr.DataArray:
"""
Helper function to standardize a dataarray before returning it.
Parameters
----------
da: xr.DataArray
The dataarray to standardize
category : str
The category of the dataarray (state/forcing/static), to load
standardization statistics for.
Returns
-------
xr.Dataarray
The standardized dataarray
"""

standard_da = self.get_standardization_dataarray(category=category)

mean = standard_da[f"{category}_mean"]
std = standard_da[f"{category}_std"]

return (da - mean) / std

@abc.abstractmethod
def get_dataarray(
self, category: str, split: str
self, category: str, split: str, standardize: bool = False
) -> Union[xr.DataArray, None]:
"""
Return the processed data (as a single `xr.DataArray`) for the given
Expand Down Expand Up @@ -219,6 +246,8 @@ def get_dataarray(
The category of the dataset (state/forcing/static).
split : str
The time split to filter the dataset (train/val/test).
standardize: bool
If the dataarray should be returned standardized
Returns
-------
Expand Down
13 changes: 11 additions & 2 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def get_num_data_vars(self, category: str) -> int:
"""
return len(self.get_vars_names(category))

def get_dataarray(self, category: str, split: str) -> xr.DataArray:
def get_dataarray(
self, category: str, split: str, standardize: bool = False
) -> xr.DataArray:
"""
Return the processed data (as a single `xr.DataArray`) for the given
category of data and test/train/val-split that covers all the data (in
Expand Down Expand Up @@ -246,6 +248,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
The category of the dataset (state/forcing/static).
split : str
The time split to filter the dataset (train/val/test).
standardize: bool
If the dataarray should be returned standardized
Returns
-------
Expand Down Expand Up @@ -283,7 +287,12 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
da_category = da_category.sel(time=slice(t_start, t_end))

dim_order = self.expected_dim_order(category=category)
return da_category.transpose(*dim_order)
da_category = da_category.transpose(*dim_order)

if standardize:
return self._standardize_datarray(da_category, category=category)

return da_category

def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
Expand Down
9 changes: 8 additions & 1 deletion neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def config(self) -> NpyDatastoreConfig:
"""
return self._config

def get_dataarray(self, category: str, split: str) -> DataArray:
def get_dataarray(
self, category: str, split: str, standardize: bool = False
) -> DataArray:
"""
Get the data array for the given category and split of data. If the
category is 'state', the data array will be a concatenation of the data
Expand All @@ -214,6 +216,8 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
split : str
The dataset split to load the data for. One of 'train', 'val', or
'test'.
standardize: bool
If the dataarray should be returned standardized
Returns
-------
Expand Down Expand Up @@ -303,6 +307,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
dim_order = self.expected_dim_order(category=category)
da = da.transpose(*dim_order)

if standardize:
return self._standardize_datarray(da, category=category)

return da

def _get_single_timeseries_dataarray(
Expand Down
11 changes: 4 additions & 7 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def __init__(
self._datastore = datastore
num_state_vars = datastore.get_num_data_vars(category="state")
num_forcing_vars = datastore.get_num_data_vars(category="forcing")
# Load static features standardized
da_static_features = datastore.get_dataarray(
category="static", split=None
category="static", split=None, standardize=True
)
da_state_stats = datastore.get_standardization_dataarray(
category="state"
Expand All @@ -49,14 +50,10 @@ def __init__(
num_past_forcing_steps = args.num_past_forcing_steps
num_future_forcing_steps = args.num_future_forcing_steps

# Load static features for grid/data, NB: self.predict_step assumes
# dimension order to be (grid_index, static_feature)
arr_static = da_static_features.transpose(
"grid_index", "static_feature"
).values
# Load static features for grid/data,
self.register_buffer(
"grid_static_features",
torch.tensor(arr_static, dtype=torch.float32),
torch.tensor(da_static_features.values, dtype=torch.float32),
persistent=False,
)

Expand Down

0 comments on commit 0805ff6

Please sign in to comment.