Skip to content

Commit

Permalink
renaming the forcing arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Nov 13, 2024
1 parent 89b10b5 commit b56e47a
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 52 deletions.
6 changes: 3 additions & 3 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def __init__(
category="state"
)
da_boundary_mask = datastore.boundary_mask
include_past_forcing = args.include_past_forcing
include_future_forcing = args.include_future_forcing
num_past_forcing_steps = args.num_past_forcing_steps
num_future_forcing_steps = args.num_future_forcing_steps

# Load static features for grid/data, NB: self.predict_step assumes
# dimension order to be (grid_index, static_feature)
Expand Down Expand Up @@ -109,7 +109,7 @@ def __init__(
2 * self.grid_output_dim
+ grid_static_dim
+ num_forcing_vars
* (include_past_forcing + include_future_forcing + 1)
* (num_past_forcing_steps + num_future_forcing_steps + 1)
)

# Instantiate loss function
Expand Down
8 changes: 4 additions & 4 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ def main(input_args=None):
metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""",
)
parser.add_argument(
"--include_past_forcing",
"--num_past_forcing_steps",
type=int,
default=1,
help="Number of past time steps to use as input for forcing data",
)
parser.add_argument(
"--include_future_forcing",
"--num_future_forcing_steps",
type=int,
default=1,
help="Number of future time steps to use as input for forcing data",
Expand Down Expand Up @@ -232,8 +232,8 @@ def main(input_args=None):
ar_steps_train=args.ar_steps_train,
ar_steps_eval=args.ar_steps_eval,
standardize=True,
include_past_forcing=args.include_past_forcing,
include_future_forcing=args.include_future_forcing,
num_past_forcing_steps=args.num_past_forcing_steps,
num_future_forcing_steps=args.num_future_forcing_steps,
batch_size=args.batch_size,
num_workers=args.num_workers,
)
Expand Down
70 changes: 42 additions & 28 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,38 @@ class WeatherDataset(torch.utils.data.Dataset):
This class loads and processes weather data from a given datastore.
Parameters
----------
datastore : BaseDatastore
The datastore to load the data from (e.g. mdp).
split : str, optional
The data split to use ("train", "val" or "test"). Default is "train".
ar_steps : int, optional
The number of autoregressive steps. Default is 3.
num_past_forcing_steps : int, optional
The number of past forcing steps to include. Default is 1.
num_future_forcing_steps : int, optional
The number of future forcing steps to include. Default is 1.
standardize : bool, optional
Whether to standardize the data. Default is True.
"""

def __init__(
self,
datastore: BaseDatastore,
split="train",
ar_steps=3,
include_past_forcing=1,
include_future_forcing=1,
num_past_forcing_steps=1,
num_future_forcing_steps=1,
standardize=True,
):
super().__init__()

self.split = split
self.ar_steps = ar_steps
self.datastore = datastore
self.include_past_forcing = include_past_forcing
self.include_future_forcing = include_future_forcing
self.num_past_forcing_steps = num_past_forcing_steps
self.num_future_forcing_steps = num_future_forcing_steps

self.da_state = self.datastore.get_dataarray(
category="state", split=self.split
Expand All @@ -54,7 +68,7 @@ def __init__(
f"configuration used in the `{split}` split. You could try "
"either reducing the number of autoregressive steps "
"(`ar_steps`) and/or the forcing window size "
"(`include_past_forcing` and `include_future_forcing`)"
"(`num_past_forcing_steps` and `num_future_forcing_steps`)"
)

# Set up for standardization
Expand Down Expand Up @@ -113,21 +127,21 @@ def __len__(self):
# Where:
# - total time steps: len(self.da_state.time)
# - autoregressive steps: self.ar_steps
# - past forcing: max(2, self.include_past_forcing) (at least 2
# - past forcing: max(2, self.num_past_forcing_steps) (at least 2
# time steps are required for the initial state)
# - future forcing: self.include_future_forcing
# - future forcing: self.num_future_forcing_steps
return (
len(self.da_state.time)
- self.ar_steps
- max(2, self.include_past_forcing)
- self.include_future_forcing
- max(2, self.num_past_forcing_steps)
- self.num_future_forcing_steps
)

def _slice_state_time(self, da_state, idx, n_steps: int):
"""
Produce a time slice of the given dataarray `da_state` (state) starting
at `idx` and with `n_steps` steps. An `offset`is calculated based on the
`include_past_forcing` class attribute. `Offset` is used to offset the
`num_past_forcing_steps` class attribute. `Offset` is used to offset the
start of the sample, to assert that enough previous time steps are
available for the 2 initial states and any corresponding forcings
(calculated in `_slice_forcing_time`).
Expand All @@ -153,8 +167,8 @@ def _slice_state_time(self, da_state, idx, n_steps: int):
# The current implementation requires at least 2 time steps for the
# initial state (see GraphCast).
init_steps = 2
start_idx = idx + max(0, self.include_past_forcing - init_steps)
end_idx = idx + max(init_steps, self.include_past_forcing) + n_steps
start_idx = idx + max(0, self.num_past_forcing_steps - init_steps)
end_idx = idx + max(init_steps, self.num_past_forcing_steps) + n_steps
# slice the dataarray to include the required number of time steps
if self.datastore.is_forecast:
# this implies that the data will have both `analysis_time` and
Expand Down Expand Up @@ -185,7 +199,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
"""
Produce a time slice of the given dataarray `da_forcing` (forcing)
starting at `idx` and with `n_steps` steps. An `offset` is calculated
based on the `include_past_forcing` class attribute. It is used to
based on the `num_past_forcing_steps` class attribute. It is used to
offset the start of the sample, to ensure that enough previous time
steps are available for the forcing data. The forcing data is windowed
around the current autoregressive time step to include the past and
Expand Down Expand Up @@ -215,7 +229,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
# as past forcings.
init_steps = 2
da_list = []
offset = idx + max(init_steps, self.include_past_forcing)
offset = idx + max(init_steps, self.num_past_forcing_steps)

if self.datastore.is_forecast:
# This implies that the data will have both `analysis_time` and
Expand All @@ -225,8 +239,8 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
# sample per forecast.
# Add a 'time' dimension using the actual forecast times
for step in range(n_steps):
start_idx = offset + step - self.include_past_forcing
end_idx = offset + step + self.include_future_forcing
start_idx = offset + step - self.num_past_forcing_steps
end_idx = offset + step + self.num_future_forcing_steps

current_time = (
da_forcing.analysis_time[idx]
Expand Down Expand Up @@ -261,8 +275,8 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
# offset is only relevant for the very first (and last) samples in
# the dataset.
for step in range(n_steps):
start_idx = offset + step - self.include_past_forcing
end_idx = offset + step + self.include_future_forcing
start_idx = offset + step - self.num_past_forcing_steps
end_idx = offset + step + self.num_future_forcing_steps

# Slice the data over the desired time window
da_sliced = da_forcing.isel(time=slice(start_idx, end_idx + 1))
Expand Down Expand Up @@ -572,15 +586,15 @@ def __init__(
ar_steps_train=3,
ar_steps_eval=25,
standardize=True,
include_past_forcing=1,
include_future_forcing=1,
num_past_forcing_steps=1,
num_future_forcing_steps=1,
batch_size=4,
num_workers=16,
):
super().__init__()
self._datastore = datastore
self.include_past_forcing = include_past_forcing
self.include_future_forcing = include_future_forcing
self.num_past_forcing_steps = num_past_forcing_steps
self.num_future_forcing_steps = num_future_forcing_steps
self.ar_steps_train = ar_steps_train
self.ar_steps_eval = ar_steps_eval
self.standardize = standardize
Expand All @@ -603,16 +617,16 @@ def setup(self, stage=None):
split="train",
ar_steps=self.ar_steps_train,
standardize=self.standardize,
include_past_forcing=self.include_past_forcing,
include_future_forcing=self.include_future_forcing,
num_past_forcing_steps=self.num_past_forcing_steps,
num_future_forcing_steps=self.num_future_forcing_steps,
)
self.val_dataset = WeatherDataset(
datastore=self._datastore,
split="val",
ar_steps=self.ar_steps_eval,
standardize=self.standardize,
include_past_forcing=self.include_past_forcing,
include_future_forcing=self.include_future_forcing,
num_past_forcing_steps=self.num_past_forcing_steps,
num_future_forcing_steps=self.num_future_forcing_steps,
)

if stage == "test" or stage is None:
Expand All @@ -621,8 +635,8 @@ def setup(self, stage=None):
split="test",
ar_steps=self.ar_steps_eval,
standardize=self.standardize,
include_past_forcing=self.include_past_forcing,
include_future_forcing=self.include_future_forcing,
num_past_forcing_steps=self.num_past_forcing_steps,
num_future_forcing_steps=self.num_future_forcing_steps,
)

def train_dataloader(self):
Expand Down
26 changes: 13 additions & 13 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def test_dataset_item_shapes(datastore_name):
N_gridpoints = datastore.num_grid_points

N_pred_steps = 4
include_past_forcing = 1
include_future_forcing = 1
num_past_forcing_steps = 1
num_future_forcing_steps = 1
dataset = WeatherDataset(
datastore=datastore,
split="train",
ar_steps=N_pred_steps,
include_past_forcing=include_past_forcing,
include_future_forcing=include_future_forcing,
num_past_forcing_steps=num_past_forcing_steps,
num_future_forcing_steps=num_future_forcing_steps,
)

item = dataset[0]
Expand All @@ -67,7 +67,7 @@ def test_dataset_item_shapes(datastore_name):
assert forcing.shape[0] == N_pred_steps
assert forcing.shape[1] == N_gridpoints
assert forcing.shape[2] == datastore.get_num_data_vars("forcing") * (
include_past_forcing + include_future_forcing + 1
num_past_forcing_steps + num_future_forcing_steps + 1
)

# batch times
Expand All @@ -85,14 +85,14 @@ def test_dataset_item_create_dataarray_from_tensor(datastore_name):
datastore = init_datastore_example(datastore_name)

N_pred_steps = 4
include_past_forcing = 1
include_future_forcing = 1
num_past_forcing_steps = 1
num_future_forcing_steps = 1
dataset = WeatherDataset(
datastore=datastore,
split="train",
ar_steps=N_pred_steps,
include_past_forcing=include_past_forcing,
include_future_forcing=include_future_forcing,
num_past_forcing_steps=num_past_forcing_steps,
num_future_forcing_steps=num_future_forcing_steps,
)

idx = 0
Expand Down Expand Up @@ -184,8 +184,8 @@ class ModelArgs:
hidden_layers = 1
processor_layers = 4
mesh_aggr = "sum"
include_past_forcing = 1
include_future_forcing = 1
num_past_forcing_steps = 1
num_future_forcing_steps = 1

args = ModelArgs()

Expand Down Expand Up @@ -248,8 +248,8 @@ def test_dataset_length(dataset_config):
datastore=datastore,
split="train",
ar_steps=dataset_config["ar_steps"],
include_past_forcing=dataset_config["past"],
include_future_forcing=dataset_config["future"],
num_past_forcing_steps=dataset_config["past"],
num_future_forcing_steps=dataset_config["future"],
)

# We expect dataset to contain this many samples
Expand Down
8 changes: 4 additions & 4 deletions tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def test_training(datastore_name):
standardize=True,
batch_size=2,
num_workers=1,
include_past_forcing=1,
include_future_forcing=1,
num_past_forcing_steps=1,
num_future_forcing_steps=1,
)

class ModelArgs:
Expand All @@ -83,8 +83,8 @@ class ModelArgs:
lr = 1.0e-3
val_steps_to_log = [1, 3]
metrics_watch = []
include_past_forcing = 1
include_future_forcing = 1
num_past_forcing_steps = 1
num_future_forcing_steps = 1

model_args = ModelArgs()

Expand Down

0 comments on commit b56e47a

Please sign in to comment.