diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 81e26d22..59ca1fdc 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -38,10 +38,14 @@ def __init__( da_state_stats = datastore.get_normalization_dataarray(category="state") da_boundary_mask = datastore.boundary_mask - # Load static features for grid/data + # Load static features for grid/data, NB: self.predict_step assumes dimension + # order to be (grid_index, static_feature) + arr_static = da_static_features.transpose( + "grid_index", "static_feature" + ).values self.register_buffer( "grid_static_features", - torch.tensor(da_static_features.values, dtype=torch.float32), + torch.tensor(arr_static, dtype=torch.float32), persistent=False, ) @@ -98,7 +102,10 @@ def __init__( boundary_mask = torch.tensor( da_boundary_mask.values, dtype=torch.float32 - ) + ).unsqueeze( + 1 + ) # add feature dim + self.register_buffer("boundary_mask", boundary_mask, persistent=False) # Pre-compute interior mask for use in loss function self.register_buffer( diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 4175e2d1..a76fc518 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -105,11 +105,6 @@ def predict_step(self, prev_state, prev_prev_state, forcing): """ batch_size = prev_state.shape[0] - print(f"prev_state.shape: {prev_state.shape}") - print(f"prev_prev_state.shape: {prev_prev_state.shape}") - print(f"forcing.shape: {forcing.shape}") - print(f"grid_static_features.shape: {self.grid_static_features.shape}") - # Create full grid node features of shape (B, num_grid_nodes, grid_dim) grid_features = torch.cat( ( diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 4e38dbd5..de5067b3 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -23,13 +23,11 @@ def __init__( split="train", ar_steps=3, forcing_window_size=3, - batch_size=4, standardize=True, ): super().__init__() self.split = split - self.batch_size = batch_size self.ar_steps = ar_steps self.datastore = datastore diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 40ca7398..f6802f5b 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,8 +1,15 @@ +# Standard library +from pathlib import Path + # Third-party import pytest +import torch from test_datastores import DATASTORES, init_datastore +from torch.utils.data import DataLoader # First-party +from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset @@ -47,9 +54,8 @@ def test_dataset_item(datastore_name): assert target_states.shape[2] == datastore.get_num_data_vars("state") # forcing - assert forcing.shape[0] == N_pred_steps # number of prediction steps - assert forcing.shape[1] == N_gridpoints # number of grid points - # number of features x window size + assert forcing.shape[0] == N_pred_steps + assert forcing.shape[1] == N_gridpoints assert ( forcing.shape[2] == datastore.get_num_data_vars("forcing") * forcing_window_size @@ -57,3 +63,57 @@ def test_dataset_item(datastore_name): # batch times assert batch_times.shape[0] == N_pred_steps + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_single_batch(datastore_name, split="train"): + """Check that the `datastore.get_dataarray` method is implemented. + + And that it returns an xarray DataArray with the correct dimensions. + """ + datastore = init_datastore(datastore_name) + + device_name = ( # noqa + torch.device("cuda") if torch.cuda.is_available() else "cpu" + ) + + graph_name = "1level" + + class ModelArgs: + output_std = False + loss = "mse" + restore_opt = False + n_example_pred = 1 + # XXX: this should be superfluous when we have already defined the + # model object no? + graph = graph_name + hidden_dim = 64 + hidden_layers = 1 + processor_layers = 4 + mesh_aggr = "sum" + + args = ModelArgs() + + graph_dir_path = Path(datastore.root_path) / "graph" / graph_name + + if not graph_dir_path.exists(): + create_graph_from_datastore( + datastore=datastore, + output_root_path=str(graph_dir_path), + n_max_levels=1, + ) + + dataset = WeatherDataset(datastore=datastore, split=split) + + model = GraphLAM( # noqa + args=args, + forcing_window_size=dataset.forcing_window_size, + datastore=datastore, + ) + + model_device = model.to(device_name) + data_loader = DataLoader(dataset, batch_size=2) + batch = next(iter(data_loader)) + model_device.common_step(batch) + + assert False