diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index dda5ac0..21a28f3 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -43,8 +43,8 @@ def __init__(
             category="state"
         )
         da_boundary_mask = datastore.boundary_mask
-        include_past_forcing = args.include_past_forcing
-        include_future_forcing = args.include_future_forcing
+        num_past_forcing_steps = args.num_past_forcing_steps
+        num_future_forcing_steps = args.num_future_forcing_steps

         # Load static features for grid/data, NB: self.predict_step assumes
         # dimension order to be (grid_index, static_feature)
@@ -109,7 +109,7 @@ def __init__(
             2 * self.grid_output_dim
             + grid_static_dim
             + num_forcing_vars
-            * (include_past_forcing + include_future_forcing + 1)
+            * (num_past_forcing_steps + num_future_forcing_steps + 1)
         )

         # Instantiate loss function
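Note (not part of the patch): the renamed step counts feed straight into the per-grid-point input dimension computed above. A quick sanity check of the formula, with invented sizes (none of these numbers come from the repo):

```python
# All sizes below are invented for illustration only
grid_output_dim = 17   # number of predicted state variables
grid_static_dim = 4    # static features per grid point
num_forcing_vars = 5   # forcing variables per time step

num_past_forcing_steps = 1
num_future_forcing_steps = 1

# Two initial states, plus static features, plus one windowed copy of
# every forcing variable (past steps + current step + future steps)
grid_dim = (
    2 * grid_output_dim
    + grid_static_dim
    + num_forcing_vars
    * (num_past_forcing_steps + num_future_forcing_steps + 1)
)
assert grid_dim == 2 * 17 + 4 + 5 * 3  # 34 + 4 + 15 = 53
```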
diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py
index 8f400b3..ca0aba4 100644
--- a/neural_lam/train_model.py
+++ b/neural_lam/train_model.py
@@ -192,13 +192,13 @@ def main(input_args=None):
         metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""",
     )
     parser.add_argument(
-        "--include_past_forcing",
+        "--num_past_forcing_steps",
         type=int,
         default=1,
         help="Number of past time steps to use as input for forcing data",
     )
     parser.add_argument(
-        "--include_future_forcing",
+        "--num_future_forcing_steps",
         type=int,
         default=1,
         help="Number of future time steps to use as input for forcing data",
@@ -232,8 +232,8 @@ def main(input_args=None):
         ar_steps_train=args.ar_steps_train,
         ar_steps_eval=args.ar_steps_eval,
         standardize=True,
-        include_past_forcing=args.include_past_forcing,
-        include_future_forcing=args.include_future_forcing,
+        num_past_forcing_steps=args.num_past_forcing_steps,
+        num_future_forcing_steps=args.num_future_forcing_steps,
         batch_size=args.batch_size,
         num_workers=args.num_workers,
     )
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 371a0c3..37af974 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -18,6 +18,20 @@ class WeatherDataset(torch.utils.data.Dataset):

     This class loads and processes weather data from a given datastore.

+    Parameters
+    ----------
+    datastore : BaseDatastore
+        The datastore to load the data from (e.g. mdp).
+    split : str, optional
+        The data split to use ("train", "val" or "test"). Default is "train".
+    ar_steps : int, optional
+        The number of autoregressive steps. Default is 3.
+    num_past_forcing_steps : int, optional
+        The number of past forcing steps to include. Default is 1.
+    num_future_forcing_steps : int, optional
+        The number of future forcing steps to include. Default is 1.
+    standardize : bool, optional
+        Whether to standardize the data. Default is True.
     """

     def __init__(
@@ -25,8 +39,8 @@ def __init__(
         datastore: BaseDatastore,
         split="train",
         ar_steps=3,
-        include_past_forcing=1,
-        include_future_forcing=1,
+        num_past_forcing_steps=1,
+        num_future_forcing_steps=1,
         standardize=True,
     ):
         super().__init__()
@@ -34,8 +48,8 @@ def __init__(
         self.split = split
         self.ar_steps = ar_steps
         self.datastore = datastore
-        self.include_past_forcing = include_past_forcing
-        self.include_future_forcing = include_future_forcing
+        self.num_past_forcing_steps = num_past_forcing_steps
+        self.num_future_forcing_steps = num_future_forcing_steps
         self.da_state = self.datastore.get_dataarray(
             category="state", split=self.split
         )
@@ -54,7 +68,7 @@ def __init__(
                 f"configuration used in the `{split}` split. You could try "
                 "either reducing the number of autoregressive steps "
                 "(`ar_steps`) and/or the forcing window size "
-                "(`include_past_forcing` and `include_future_forcing`)"
+                "(`num_past_forcing_steps` and `num_future_forcing_steps`)"
             )

         # Set up for standardization
@@ -113,21 +127,21 @@ def __len__(self):
         # Where:
         # - total time steps: len(self.da_state.time)
         # - autoregressive steps: self.ar_steps
-        # - past forcing: max(2, self.include_past_forcing) (at least 2
+        # - past forcing: max(2, self.num_past_forcing_steps) (at least 2
         #   time steps are required for the initial state)
-        # - future forcing: self.include_future_forcing
+        # - future forcing: self.num_future_forcing_steps
         return (
             len(self.da_state.time)
             - self.ar_steps
-            - max(2, self.include_past_forcing)
-            - self.include_future_forcing
+            - max(2, self.num_past_forcing_steps)
+            - self.num_future_forcing_steps
         )

     def _slice_state_time(self, da_state, idx, n_steps: int):
         """
         Produce a time slice of the given dataarray `da_state` (state) starting
         at `idx` and with `n_steps` steps. An `offset` is calculated based on the
-        `include_past_forcing` class attribute. `Offset` is used to offset the
+        `num_past_forcing_steps` class attribute. `Offset` is used to offset the
         start of the sample, to ensure that enough previous time steps are
         available for the 2 initial states and any corresponding forcings
         (calculated in `_slice_forcing_time`).
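Note (not part of the patch): the `__len__` formula above is easy to check with concrete numbers. A worked example with an invented split size:

```python
# Invented example: a split providing 40 total time steps
total_time_steps = 40
ar_steps = 3
num_past_forcing_steps = 1   # fewer than the 2 reserved initial-state steps
num_future_forcing_steps = 1

# max(2, num_past_forcing_steps) reserves at least the 2 time steps
# needed for the initial states at the start of the series
n_samples = (
    total_time_steps
    - ar_steps
    - max(2, num_past_forcing_steps)
    - num_future_forcing_steps
)
assert n_samples == 34  # 40 - 3 - 2 - 1
```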
@@ -153,8 +167,8 @@ def _slice_state_time(self, da_state, idx, n_steps: int):
         # The current implementation requires at least 2 time steps for the
         # initial state (see GraphCast).
         init_steps = 2
-        start_idx = idx + max(0, self.include_past_forcing - init_steps)
-        end_idx = idx + max(init_steps, self.include_past_forcing) + n_steps
+        start_idx = idx + max(0, self.num_past_forcing_steps - init_steps)
+        end_idx = idx + max(init_steps, self.num_past_forcing_steps) + n_steps
         # slice the dataarray to include the required number of time steps
         if self.datastore.is_forecast:
             # this implies that the data will have both `analysis_time` and
@@ -185,7 +199,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
         """
         Produce a time slice of the given dataarray `da_forcing` (forcing)
         starting at `idx` and with `n_steps` steps. An `offset` is calculated
-        based on the `include_past_forcing` class attribute. It is used to
+        based on the `num_past_forcing_steps` class attribute. It is used to
         offset the start of the sample, to ensure that enough previous time
         steps are available for the forcing data. The forcing data is windowed
         around the current autoregressive time step to include the past and
@@ -215,7 +229,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
         # as past forcings.
         init_steps = 2
         da_list = []
-        offset = idx + max(init_steps, self.include_past_forcing)
+        offset = idx + max(init_steps, self.num_past_forcing_steps)

         if self.datastore.is_forecast:
             # This implies that the data will have both `analysis_time` and
@@ -225,8 +239,8 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
             # sample per forecast.
             # Add a 'time' dimension using the actual forecast times
             for step in range(n_steps):
-                start_idx = offset + step - self.include_past_forcing
-                end_idx = offset + step + self.include_future_forcing
+                start_idx = offset + step - self.num_past_forcing_steps
+                end_idx = offset + step + self.num_future_forcing_steps

                 current_time = (
                     da_forcing.analysis_time[idx]
@@ -261,8 +275,8 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
             # offset is only relevant for the very first (and last) samples in
             # the dataset.
             for step in range(n_steps):
-                start_idx = offset + step - self.include_past_forcing
-                end_idx = offset + step + self.include_future_forcing
+                start_idx = offset + step - self.num_past_forcing_steps
+                end_idx = offset + step + self.num_future_forcing_steps

                 # Slice the data over the desired time window
                 da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1))
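Note (not part of the patch): the window arithmetic in both branches above is identical. A small standalone trace with invented values shows how each autoregressive step receives a (past + current + future) forcing window:

```python
idx = 0            # first sample in the dataset
init_steps = 2     # the two initial states, as in `_slice_forcing_time`
num_past_forcing_steps = 1
num_future_forcing_steps = 1
n_steps = 3        # autoregressive steps

offset = idx + max(init_steps, num_past_forcing_steps)  # -> 2
for step in range(n_steps):
    start_idx = offset + step - num_past_forcing_steps
    end_idx = offset + step + num_future_forcing_steps
    # slice(start_idx, end_idx + 1) is inclusive of the future step:
    # step 0 -> indices [1, 2, 3], step 1 -> [2, 3, 4], step 2 -> [3, 4, 5]
    print(step, list(range(start_idx, end_idx + 1)))
```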
@@ -572,15 +586,15 @@ def __init__(
         ar_steps_train=3,
         ar_steps_eval=25,
         standardize=True,
-        include_past_forcing=1,
-        include_future_forcing=1,
+        num_past_forcing_steps=1,
+        num_future_forcing_steps=1,
         batch_size=4,
         num_workers=16,
     ):
         super().__init__()
         self._datastore = datastore
-        self.include_past_forcing = include_past_forcing
-        self.include_future_forcing = include_future_forcing
+        self.num_past_forcing_steps = num_past_forcing_steps
+        self.num_future_forcing_steps = num_future_forcing_steps
         self.ar_steps_train = ar_steps_train
         self.ar_steps_eval = ar_steps_eval
         self.standardize = standardize
@@ -603,16 +617,16 @@ def setup(self, stage=None):
                 split="train",
                 ar_steps=self.ar_steps_train,
                 standardize=self.standardize,
-                include_past_forcing=self.include_past_forcing,
-                include_future_forcing=self.include_future_forcing,
+                num_past_forcing_steps=self.num_past_forcing_steps,
+                num_future_forcing_steps=self.num_future_forcing_steps,
             )
             self.val_dataset = WeatherDataset(
                 datastore=self._datastore,
                 split="val",
                 ar_steps=self.ar_steps_eval,
                 standardize=self.standardize,
-                include_past_forcing=self.include_past_forcing,
-                include_future_forcing=self.include_future_forcing,
+                num_past_forcing_steps=self.num_past_forcing_steps,
+                num_future_forcing_steps=self.num_future_forcing_steps,
             )

         if stage == "test" or stage is None:
@@ -621,8 +635,8 @@ def setup(self, stage=None):
                 split="test",
                 ar_steps=self.ar_steps_eval,
                 standardize=self.standardize,
-                include_past_forcing=self.include_past_forcing,
-                include_future_forcing=self.include_future_forcing,
+                num_past_forcing_steps=self.num_past_forcing_steps,
+                num_future_forcing_steps=self.num_future_forcing_steps,
             )

     def train_dataloader(self):
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 11cad74..9e89b43 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -34,14 +34,14 @@ def test_dataset_item_shapes(datastore_name):
     N_gridpoints = datastore.num_grid_points

     N_pred_steps = 4
-    include_past_forcing = 1
-    include_future_forcing = 1
+    num_past_forcing_steps = 1
+    num_future_forcing_steps = 1
     dataset = WeatherDataset(
         datastore=datastore,
         split="train",
         ar_steps=N_pred_steps,
-        include_past_forcing=include_past_forcing,
-        include_future_forcing=include_future_forcing,
+        num_past_forcing_steps=num_past_forcing_steps,
+        num_future_forcing_steps=num_future_forcing_steps,
     )

     item = dataset[0]
@@ -67,7 +67,7 @@ def test_dataset_item_shapes(datastore_name):
     assert forcing.shape[0] == N_pred_steps
     assert forcing.shape[1] == N_gridpoints
     assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * (
-        include_past_forcing + include_future_forcing + 1
+        num_past_forcing_steps + num_future_forcing_steps + 1
     )

     # batch times
@@ -85,14 +85,14 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name):
     datastore = init_datastore_example(datastore_name)

     N_pred_steps = 4
-    include_past_forcing = 1
-    include_future_forcing = 1
+    num_past_forcing_steps = 1
+    num_future_forcing_steps = 1
     dataset = WeatherDataset(
         datastore=datastore,
         split="train",
         ar_steps=N_pred_steps,
-        include_past_forcing=include_past_forcing,
-        include_future_forcing=include_future_forcing,
+        num_past_forcing_steps=num_past_forcing_steps,
+        num_future_forcing_steps=num_future_forcing_steps,
     )

     idx = 0
@@ -184,8 +184,8 @@ class ModelArgs:
         hidden_layers = 1
         processor_layers = 4
         mesh_aggr = "sum"
-        include_past_forcing = 1
-        include_future_forcing = 1
+        num_past_forcing_steps = 1
+        num_future_forcing_steps = 1

     args = ModelArgs()
@@ -248,8 +248,8 @@ def test_dataset_length(dataset_config):
         datastore=datastore,
         split="train",
         ar_steps=dataset_config["ar_steps"],
-        include_past_forcing=dataset_config["past"],
-        include_future_forcing=dataset_config["future"],
+        num_past_forcing_steps=dataset_config["past"],
+        num_future_forcing_steps=dataset_config["future"],
     )

     # We expect dataset to contain this many samples
diff --git a/tests/test_training.py b/tests/test_training.py
index 87c9748..7ad28e4 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -64,8 +64,8 @@ def test_training(datastore_name):
         standardize=True,
         batch_size=2,
         num_workers=1,
-        include_past_forcing=1,
-        include_future_forcing=1,
+        num_past_forcing_steps=1,
+        num_future_forcing_steps=1,
     )

     class ModelArgs:
@@ -83,8 +83,8 @@ class ModelArgs:
         lr = 1.0e-3
         val_steps_to_log = [1, 3]
         metrics_watch = []
-        include_past_forcing = 1
-        include_future_forcing = 1
+        num_past_forcing_steps = 1
+        num_future_forcing_steps = 1

     model_args = ModelArgs()
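Note (not part of the patch): for downstream code the only visible change is the keyword names. A minimal construction sketch with the renamed arguments; the datastore is assumed to be supplied by the caller (any `BaseDatastore` implementation, as in the tests above):

```python
from neural_lam.weather_dataset import WeatherDataset

def make_train_dataset(datastore):
    # `datastore` is any BaseDatastore instance set up elsewhere;
    # only the keyword names changed, not the returned item structure
    return WeatherDataset(
        datastore=datastore,
        split="train",
        ar_steps=3,
        num_past_forcing_steps=1,
        num_future_forcing_steps=1,
        standardize=True,
    )
```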