Wrap up first version of new graph tests
joeloskarsson committed Dec 7, 2024
1 parent a1f0f62 commit 797b867
Showing 4 changed files with 179 additions and 81 deletions.
71 changes: 71 additions & 0 deletions tests/conftest.py
@@ -5,6 +5,7 @@
# Third-party
import numpy as np
import pooch
import torch
import yaml

# First-party
@@ -135,3 +136,73 @@ def get_test_mesh_dist(datastore, datastore_boundary):

# Want at least 10 mesh nodes in each direction
return min_extent / 10.0


def check_saved_graph(graph_dir_path, hierarchical):
"""Perform all checking for a saved graph"""
required_graph_files = [
"m2m_edge_index.pt",
"g2m_edge_index.pt",
"m2g_edge_index.pt",
"m2m_features.pt",
"g2m_features.pt",
"m2g_features.pt",
"m2m_node_features.pt",
]

if hierarchical:
required_graph_files.extend(
[
"mesh_up_edge_index.pt",
"mesh_down_edge_index.pt",
"mesh_up_features.pt",
"mesh_down_features.pt",
]
)
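        # Hierarchical graphs in these tests are built with 3 mesh levels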
num_levels = 3

    # TODO: check that the number of edges is consistent across the files; for
    # now we just check the number of features
d_features = 3
d_mesh_static = 2

assert graph_dir_path.exists()

# check that all the required files are present
for file_name in required_graph_files:
assert (graph_dir_path / file_name).exists()

# try to load each and ensure they have the right shape
for file_name in required_graph_files:
file_id = Path(file_name).stem # remove the extension
result = torch.load(graph_dir_path / file_name)

if file_id.startswith("g2m") or file_id.startswith("m2g"):
assert isinstance(result, torch.Tensor)

if file_id.endswith("_index"):
assert result.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert result.shape[1] == d_features

elif file_id.startswith("m2m") or file_id.startswith("mesh"):
assert isinstance(result, list)
if not hierarchical:
assert len(result) == 1
else:
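                # Up and down edges connect pairs of adjacent levels,
                # so there is one fewer set of them than there are levels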
if file_id.startswith("mesh_up") or file_id.startswith(
"mesh_down"
):
assert len(result) == num_levels - 1
else:
assert len(result) == num_levels

for r in result:
assert isinstance(r, torch.Tensor)

if file_id == "m2m_node_features":
assert r.shape[1] == d_mesh_static
elif file_id.endswith("_index"):
assert r.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert r.shape[1] == d_features
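
For orientation, a minimal sketch of how this helper is driven, mirroring the calls made in the tests below. The graph name here is illustrative (not from this commit), and `datastore`/`datastore_boundary` are assumed to be initialised via the existing fixtures:

from pathlib import Path

from neural_lam.build_rectangular_graph import build_graph_from_archetype
from tests.conftest import check_saved_graph, get_test_mesh_dist

# Hypothetical usage sketch of check_saved_graph (not part of this commit):
# build a flat single-level graph, then validate the saved files.
graph_name = "example_flat"  # illustrative name
graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name
build_graph_from_archetype(
    datastore,
    datastore_boundary,
    graph_name,
    "keisler",
    mesh_node_distance=get_test_mesh_dist(datastore, datastore_boundary),
)
check_saved_graph(graph_dir_path, hierarchical=False)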
18 changes: 13 additions & 5 deletions tests/test_datasets.py
@@ -209,31 +209,34 @@ def test_single_batch(datastore_name, datastore_boundary_name, split):
torch.device("cuda") if torch.cuda.is_available() else "cpu"
) # noqa

graph_name = "1level"
flat_graph_name = "1level"

class ModelArgs:
output_std = False
loss = "mse"
restore_opt = False
n_example_pred = 1
graph = graph_name
graph_name = flat_graph_name
hidden_dim = 4
hidden_layers = 1
processor_layers = 2
mesh_aggr = "sum"
num_past_forcing_steps = 1
num_future_forcing_steps = 1
num_past_boundary_steps = 1
num_future_boundary_steps = 1
shared_grid_embedder = False

args = ModelArgs()

graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name
graph_dir_path = Path(datastore.root_path) / "graphs" / flat_graph_name

def _create_graph():
if not graph_dir_path.exists():
build_graph_from_archetype(
datastore=datastore,
datastore_boundary=datastore_boundary,
graph_name=graph_name,
graph_name=flat_graph_name,
archetype="keisler",
mesh_node_distance=get_test_mesh_dist(
datastore, datastore_boundary
@@ -257,7 +260,12 @@ def _create_graph():
datastore=datastore, datastore_boundary=datastore_boundary, split=split
)

model = GraphLAM(args=args, datastore=datastore, config=config) # noqa
model = GraphLAM(
args=args,
datastore=datastore,
datastore_boundary=datastore_boundary,
config=config,
) # noqa

model_device = model.to(device_name)
data_loader = DataLoader(dataset, batch_size=2)
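The rest of this test is unchanged by the commit and elided here. For context, a sketch of what the single-batch check presumably does next, assuming `common_step` is the model's one-step entry point as in earlier versions of this test (an assumption, not shown in this diff):

# Sketch of the elided continuation: run one batch through the model
batch = next(iter(data_loader))
batch_device = [part.to(device_name) for part in batch]
model_device.common_step(batch_device)  # forward pass only, no training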
156 changes: 86 additions & 70 deletions tests/test_graph_creation.py
@@ -3,7 +3,6 @@

# Third-party
import pytest
import torch

# First-party
from neural_lam.build_rectangular_graph import (
@@ -13,6 +12,7 @@
from neural_lam.datastore import DATASTORES
from tests.conftest import (
DATASTORES_BOUNDARY_EXAMPLES,
check_saved_graph,
get_test_mesh_dist,
init_datastore_boundary_example,
init_datastore_example,
@@ -25,8 +25,8 @@
list(DATASTORES_BOUNDARY_EXAMPLES.keys()) + [None],
)
@pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"])
def test_graph_creation(datastore_name, datastore_boundary_name, archetype):
"""Check that the `create_ graph_from_datastore` function is implemented.
def test_build_archetype(datastore_name, datastore_boundary_name, archetype):
"""Check that the `build_graph_from_archetype` function is implemented.
And that the graph is created in the correct location.
"""
@@ -50,86 +50,102 @@ def test_graph_creation(datastore_name, datastore_boundary_name, archetype):
# Add additional multi-level kwargs
create_kwargs.update(
{
"level_refinement_factor": 3,
"level_refinement_factor": 2,
"max_num_levels": num_levels,
}
)

required_graph_files = [
"m2m_edge_index.pt",
"g2m_edge_index.pt",
"m2g_edge_index.pt",
"m2m_features.pt",
"g2m_features.pt",
"m2g_features.pt",
"m2m_node_features.pt",
]
# Name graph
graph_name = f"{datastore_name}_{datastore_boundary_name}_{archetype}"

# Saved in datastore
# TODO: Maybe save in tmp dir?
graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name

build_graph_from_archetype(
datastore, datastore_boundary, graph_name, archetype, **create_kwargs
)

hierarchical = archetype == "hierarchical"
if hierarchical:
required_graph_files.extend(
[
"mesh_up_edge_index.pt",
"mesh_down_edge_index.pt",
"mesh_up_features.pt",
"mesh_down_features.pt",
]
check_saved_graph(graph_dir_path, hierarchical)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
@pytest.mark.parametrize(
"datastore_boundary_name",
list(DATASTORES_BOUNDARY_EXAMPLES.keys()) + [None],
)
@pytest.mark.parametrize(
"config_i, graph_kwargs",
enumerate(
[
# Assortment of options
{
"m2m_connectivity": "flat",
"m2g_connectivity": "nearest_neighbour",
"g2m_connectivity": "nearest_neighbour",
"m2m_connectivity_kwargs": {},
},
{
"m2m_connectivity": "flat_multiscale",
"m2g_connectivity": "nearest_neighbours",
"g2m_connectivity": "within_radius",
"m2m_connectivity_kwargs": {
"level_refinement_factor": 2,
},
"m2g_connectivity_kwargs": {
"max_num_neighbours": 4,
},
"g2m_connectivity_kwargs": {
"rel_max_dist": 0.3,
},
},
{
"m2m_connectivity": "hierarchical",
"m2g_connectivity": "containing_rectangle",
"g2m_connectivity": "within_radius",
"m2m_connectivity_kwargs": {
"level_refinement_factor": 2,
},
"m2g_connectivity_kwargs": {},
"g2m_connectivity_kwargs": {
"rel_max_dist": 0.51,
},
},
]
),
)
def test_build_from_options(
datastore_name, datastore_boundary_name, config_i, graph_kwargs
):
"""Check that the `build_graph_from_archetype` function is implemented.
And that the graph is created in the correct location.
"""
datastore = init_datastore_example(datastore_name)

if datastore_boundary_name is None:
# LAM scale
datastore_boundary = None
else:
# Global scale, ERA5 coords flattened with proj
datastore_boundary = init_datastore_boundary_example(
datastore_boundary_name
)
num_levels = 3

# TODO: check that the number of edges is consistent over the files, for
# now we just check the number of features
d_features = 3
d_mesh_static = 2
# Insert mesh distance
graph_kwargs["m2m_connectivity_kwargs"][
"mesh_node_distance"
] = get_test_mesh_dist(datastore, datastore_boundary)

# Name graph
graph_name = f"{datastore_name}_{datastore_boundary_name}_{archetype}"
graph_name = f"{datastore_name}_{datastore_boundary_name}_config{config_i}"

# Saved in datastore
# TODO: Maybe save in tmp dir?
graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name

build_graph_from_archetype(
datastore, datastore_boundary, graph_name, archetype, **create_kwargs
)
build_graph(datastore, datastore_boundary, graph_name, **graph_kwargs)

assert graph_dir_path.exists()

# check that all the required files are present
for file_name in required_graph_files:
assert (graph_dir_path / file_name).exists()

# try to load each and ensure they have the right shape
for file_name in required_graph_files:
file_id = Path(file_name).stem # remove the extension
result = torch.load(graph_dir_path / file_name)

if file_id.startswith("g2m") or file_id.startswith("m2g"):
assert isinstance(result, torch.Tensor)

if file_id.endswith("_index"):
assert result.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert result.shape[1] == d_features

elif file_id.startswith("m2m") or file_id.startswith("mesh"):
assert isinstance(result, list)
if not hierarchical:
assert len(result) == 1
else:
if file_id.startswith("mesh_up") or file_id.startswith(
"mesh_down"
):
assert len(result) == num_levels - 1
else:
assert len(result) == num_levels

for r in result:
assert isinstance(r, torch.Tensor)

if file_id == "m2m_node_features":
assert r.shape[1] == d_mesh_static
elif file_id.endswith("_index"):
assert r.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert r.shape[1] == d_features
hierarchical = graph_kwargs["m2m_connectivity"] == "hierarchical"
check_saved_graph(graph_dir_path, hierarchical)
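
Between them, these two tests cover both entry points: `build_graph_from_archetype` resolves a named archetype ("keisler", "graphcast" or "hierarchical") into connectivity options, while `build_graph` takes the options directly. A rough sketch of the assumed correspondence for the flat, single-level case (illustrative only; the actual expansion lives in `neural_lam.build_rectangular_graph`, and the connectivity defaults below are assumptions, not taken from this commit):

# Assumed rough equivalence (sketch, not from this commit):
build_graph_from_archetype(
    datastore, datastore_boundary, graph_name, "keisler",
    mesh_node_distance=mesh_dist,
)
# ...is expected to behave like an explicit flat configuration:
build_graph(
    datastore, datastore_boundary, graph_name,
    m2m_connectivity="flat",
    g2m_connectivity="nearest_neighbour",
    m2g_connectivity="nearest_neighbour",
    m2m_connectivity_kwargs={"mesh_node_distance": mesh_dist},
)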
15 changes: 9 additions & 6 deletions tests/test_training.py
@@ -62,16 +62,16 @@ def test_training(datastore_name, datastore_boundary_name):
log_every_n_steps=1,
)

graph_name = "1level"
flat_graph_name = "1level"

graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name
graph_dir_path = Path(datastore.root_path) / "graphs" / flat_graph_name

def _create_graph():
if not graph_dir_path.exists():
build_graph_from_archetype(
datastore=datastore,
datastore_boundary=datastore_boundary,
graph_name=graph_name,
graph_name=flat_graph_name,
archetype="keisler",
mesh_node_distance=get_test_mesh_dist(
datastore, datastore_boundary
@@ -99,7 +99,7 @@ class ModelArgs:
n_example_pred = 1
# XXX: this should be superfluous when we have already defined the
# model object no?
graph = graph_name
graph_name = flat_graph_name
hidden_dim = 4
hidden_layers = 1
processor_layers = 2
@@ -111,6 +111,7 @@
num_future_forcing_steps = 1
num_past_boundary_steps = 1
num_future_boundary_steps = 1
shared_grid_embedder = False

model_args = ModelArgs()

@@ -120,10 +121,12 @@ class ModelArgs:
)
)

model = GraphLAM( # noqa
model = GraphLAM(
args=model_args,
datastore=datastore,
datastore_boundary=datastore_boundary,
config=config,
)
) # noqa

wandb.init()
trainer.fit(model=model, datamodule=data_module)
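
All of these tests are parametrized over the datastore and boundary-datastore examples; a typical local invocation of just the new graph-creation tests might look like this (hypothetical command, assuming a standard pytest setup):

python -m pytest tests/test_graph_creation.py -k "test_build_archetype or test_build_from_options"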
