Skip to content

Commit

Permalink
test for single batch from mllam through model
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Jul 24, 2024
1 parent 8698719 commit 3381404
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 13 deletions.
13 changes: 10 additions & 3 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ def __init__(
da_state_stats = datastore.get_normalization_dataarray(category="state")
da_boundary_mask = datastore.boundary_mask

# Load static features for grid/data
# Load static features for grid/data, NB: self.predict_step assumes dimension
# order to be (grid_index, static_feature)
arr_static = da_static_features.transpose(
"grid_index", "static_feature"
).values
self.register_buffer(
"grid_static_features",
torch.tensor(da_static_features.values, dtype=torch.float32),
torch.tensor(arr_static, dtype=torch.float32),
persistent=False,
)

Expand Down Expand Up @@ -98,7 +102,10 @@ def __init__(

boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.float32
)
).unsqueeze(
1
) # add feature dim

self.register_buffer("boundary_mask", boundary_mask, persistent=False)
# Pre-compute interior mask for use in loss function
self.register_buffer(
Expand Down
5 changes: 0 additions & 5 deletions neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,6 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
"""
batch_size = prev_state.shape[0]

print(f"prev_state.shape: {prev_state.shape}")
print(f"prev_prev_state.shape: {prev_prev_state.shape}")
print(f"forcing.shape: {forcing.shape}")
print(f"grid_static_features.shape: {self.grid_static_features.shape}")

# Create full grid node features of shape (B, num_grid_nodes, grid_dim)
grid_features = torch.cat(
(
Expand Down
2 changes: 0 additions & 2 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ def __init__(
split="train",
ar_steps=3,
forcing_window_size=3,
batch_size=4,
standardize=True,
):
super().__init__()

self.split = split
self.batch_size = batch_size
self.ar_steps = ar_steps
self.datastore = datastore

Expand Down
66 changes: 63 additions & 3 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# Standard library
from pathlib import Path

# Third-party
import pytest
import torch
from test_datastores import DATASTORES, init_datastore
from torch.utils.data import DataLoader

# First-party
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.weather_dataset import WeatherDataset


Expand Down Expand Up @@ -47,13 +54,66 @@ def test_dataset_item(datastore_name):
assert target_states.shape[2] == datastore.get_num_data_vars("state")

# forcing
assert forcing.shape[0] == N_pred_steps # number of prediction steps
assert forcing.shape[1] == N_gridpoints # number of grid points
# number of features x window size
assert forcing.shape[0] == N_pred_steps
assert forcing.shape[1] == N_gridpoints
assert (
forcing.shape[2]
== datastore.get_num_data_vars("forcing") * forcing_window_size
)

# batch times
assert batch_times.shape[0] == N_pred_steps


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_single_batch(datastore_name, split="train"):
    """Check that a single batch can be passed through the model.

    A `WeatherDataset` is built from the datastore, one batch is drawn from
    a `DataLoader` wrapping it, and the batch is fed to
    `GraphLAM.common_step`. The test passes if no exception is raised.

    Parameters
    ----------
    datastore_name : str
        Key into `DATASTORES` selecting which datastore to initialise.
    split : str
        Name of the dataset split to draw the batch from.
    """
    datastore = init_datastore(datastore_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    graph_name = "1level"

    class ModelArgs:
        output_std = False
        loss = "mse"
        restore_opt = False
        n_example_pred = 1
        # XXX: this should be superfluous when we have already defined the
        # model object no?
        graph = graph_name
        hidden_dim = 64
        hidden_layers = 1
        processor_layers = 4
        mesh_aggr = "sum"

    args = ModelArgs()

    graph_dir_path = Path(datastore.root_path) / "graph" / graph_name

    # Only build the graph when it is not already present on disk
    if not graph_dir_path.exists():
        create_graph_from_datastore(
            datastore=datastore,
            output_root_path=str(graph_dir_path),
            n_max_levels=1,
        )

    dataset = WeatherDataset(datastore=datastore, split=split)

    model = GraphLAM(
        args=args,
        forcing_window_size=dataset.forcing_window_size,
        datastore=datastore,
    )

    model_device = model.to(device)
    data_loader = DataLoader(dataset, batch_size=2)
    batch = next(iter(data_loader))

    # The model may live on the GPU, so the tensors of the batch must be
    # moved to the same device; non-tensor entries are passed through as-is.
    batch = tuple(
        item.to(device) if isinstance(item, torch.Tensor) else item
        for item in batch
    )

    model_device.common_step(batch)

0 comments on commit 3381404

Please sign in to comment.