diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index b0055e3..f029165 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -186,9 +186,36 @@ def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ pass + def _standardize_datarray( + self, da: xr.DataArray, category: str + ) -> xr.DataArray: + """ + Helper function to standardize a dataarray before returning it. + + Parameters + ---------- + da: xr.DataArray + The dataarray to standardize + category : str + The category of the dataarray (state/forcing/static), to load + standardization statistics for. + + Returns + ------- + xr.Dataarray + The standardized dataarray + """ + + standard_da = self.get_standardization_dataarray(category=category) + + mean = standard_da[f"{category}_mean"] + std = standard_da[f"{category}_std"] + + return (da - mean) / std + @abc.abstractmethod def get_dataarray( - self, category: str, split: str + self, category: str, split: str, standardize: bool = False ) -> Union[xr.DataArray, None]: """ Return the processed data (as a single `xr.DataArray`) for the given @@ -219,6 +246,8 @@ def get_dataarray( The category of the dataset (state/forcing/static). split : str The time split to filter the dataset (train/val/test). + standardize: bool + If the dataarray should be returned standardized Returns ------- diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py index 0d1aac7..0b6bb5e 100644 --- a/neural_lam/datastore/mdp.py +++ b/neural_lam/datastore/mdp.py @@ -218,7 +218,9 @@ def get_num_data_vars(self, category: str) -> int: """ return len(self.get_vars_names(category)) - def get_dataarray(self, category: str, split: str) -> xr.DataArray: + def get_dataarray( + self, category: str, split: str, standardize: bool = False + ) -> xr.DataArray: """ Return the processed data (as a single `xr.DataArray`) for the given category of data and test/train/val-split that covers all the data (in @@ -246,6 +248,8 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: The category of the dataset (state/forcing/static). split : str The time split to filter the dataset (train/val/test). + standardize: bool + If the dataarray should be returned standardized Returns ------- @@ -283,7 +287,12 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray: da_category = da_category.sel(time=slice(t_start, t_end)) dim_order = self.expected_dim_order(category=category) - return da_category.transpose(*dim_order) + da_category = da_category.transpose(*dim_order) + + if standardize: + return self._standardize_datarray(da_category, category=category) + + return da_category def get_standardization_dataarray(self, category: str) -> xr.Dataset: """ diff --git a/neural_lam/datastore/npyfilesmeps/store.py b/neural_lam/datastore/npyfilesmeps/store.py index 42e8070..8f926f7 100644 --- a/neural_lam/datastore/npyfilesmeps/store.py +++ b/neural_lam/datastore/npyfilesmeps/store.py @@ -199,7 +199,9 @@ def config(self) -> NpyDatastoreConfig: """ return self._config - def get_dataarray(self, category: str, split: str) -> DataArray: + def get_dataarray( + self, category: str, split: str, standardize: bool = False + ) -> DataArray: """ Get the data array for the given category and split of data. If the category is 'state', the data array will be a concatenation of the data @@ -214,6 +216,8 @@ def get_dataarray(self, category: str, split: str) -> DataArray: split : str The dataset split to load the data for. One of 'train', 'val', or 'test'. + standardize: bool + If the dataarray should be returned standardized Returns ------- @@ -303,6 +307,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray: dim_order = self.expected_dim_order(category=category) da = da.transpose(*dim_order) + if standardize: + return self._standardize_datarray(da, category=category) + return da def _get_single_timeseries_dataarray( diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 44baf9c..754cfb3 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -39,8 +39,9 @@ def __init__( self._datastore = datastore num_state_vars = datastore.get_num_data_vars(category="state") num_forcing_vars = datastore.get_num_data_vars(category="forcing") + # Load static features standardized da_static_features = datastore.get_dataarray( - category="static", split=None + category="static", split=None, standardize=True ) da_state_stats = datastore.get_standardization_dataarray( category="state" @@ -49,14 +50,10 @@ def __init__( num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps - # Load static features for grid/data, NB: self.predict_step assumes - # dimension order to be (grid_index, static_feature) - arr_static = da_static_features.transpose( - "grid_index", "static_feature" - ).values + # Load static features for grid/data, self.register_buffer( "grid_static_features", - torch.tensor(arr_static, dtype=torch.float32), + torch.tensor(da_static_features.values, dtype=torch.float32), persistent=False, )