From 797b86774e35ab488c8d3631693098b2d969b392 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Sat, 7 Dec 2024 10:28:46 +0100 Subject: [PATCH] Wrap up first version of new graph tests --- tests/conftest.py | 71 ++++++++++++++++ tests/test_datasets.py | 18 ++-- tests/test_graph_creation.py | 156 +++++++++++++++++++---------------- tests/test_training.py | 15 ++-- 4 files changed, 179 insertions(+), 81 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index c2dc214..7dfb8ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ # Third-party import numpy as np import pooch +import torch import yaml # First-party @@ -135,3 +136,73 @@ def get_test_mesh_dist(datastore, datastore_boundary): # Want at least 10 mesh nodes in each direction return min_extent / 10.0 + + +def check_saved_graph(graph_dir_path, hierarchical): + """Perform all checking for a saved graph""" + required_graph_files = [ + "m2m_edge_index.pt", + "g2m_edge_index.pt", + "m2g_edge_index.pt", + "m2m_features.pt", + "g2m_features.pt", + "m2g_features.pt", + "m2m_node_features.pt", + ] + + if hierarchical: + required_graph_files.extend( + [ + "mesh_up_edge_index.pt", + "mesh_down_edge_index.pt", + "mesh_up_features.pt", + "mesh_down_features.pt", + ] + ) + num_levels = 3 + + # TODO: check that the number of edges is consistent over the files, for + # now we just check the number of features + d_features = 3 + d_mesh_static = 2 + + assert graph_dir_path.exists() + + # check that all the required files are present + for file_name in required_graph_files: + assert (graph_dir_path / file_name).exists() + + # try to load each and ensure they have the right shape + for file_name in required_graph_files: + file_id = Path(file_name).stem # remove the extension + result = torch.load(graph_dir_path / file_name) + + if file_id.startswith("g2m") or file_id.startswith("m2g"): + assert isinstance(result, torch.Tensor) + + if file_id.endswith("_index"): + assert result.shape[0] == 2 # adjacency matrix uses two rows + elif file_id.endswith("_features"): + assert result.shape[1] == d_features + + elif file_id.startswith("m2m") or file_id.startswith("mesh"): + assert isinstance(result, list) + if not hierarchical: + assert len(result) == 1 + else: + if file_id.startswith("mesh_up") or file_id.startswith( + "mesh_down" + ): + assert len(result) == num_levels - 1 + else: + assert len(result) == num_levels + + for r in result: + assert isinstance(r, torch.Tensor) + + if file_id == "m2m_node_features": + assert r.shape[1] == d_mesh_static + elif file_id.endswith("_index"): + assert r.shape[0] == 2 # adjacency matrix uses two rows + elif file_id.endswith("_features"): + assert r.shape[1] == d_features diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 748abe8..5735f52 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -209,31 +209,34 @@ def test_single_batch(datastore_name, datastore_boundary_name, split): torch.device("cuda") if torch.cuda.is_available() else "cpu" ) # noqa - graph_name = "1level" + flat_graph_name = "1level" class ModelArgs: output_std = False loss = "mse" restore_opt = False n_example_pred = 1 - graph = graph_name + graph_name = flat_graph_name hidden_dim = 4 hidden_layers = 1 processor_layers = 2 mesh_aggr = "sum" num_past_forcing_steps = 1 num_future_forcing_steps = 1 + num_past_boundary_steps = 1 + num_future_boundary_steps = 1 + shared_grid_embedder = False args = ModelArgs() - graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name + graph_dir_path = Path(datastore.root_path) / "graphs" / flat_graph_name def _create_graph(): if not graph_dir_path.exists(): build_graph_from_archetype( datastore=datastore, datastore_boundary=datastore_boundary, - graph_name=graph_name, + graph_name=flat_graph_name, archetype="keisler", mesh_node_distance=get_test_mesh_dist( datastore, datastore_boundary @@ -257,7 +260,12 @@ def _create_graph(): datastore=datastore, datastore_boundary=datastore_boundary, split=split ) - model = GraphLAM(args=args, datastore=datastore, config=config) # noqa + model = GraphLAM( + args=args, + datastore=datastore, + datastore_boundary=datastore_boundary, + config=config, + ) # noqa model_device = model.to(device_name) data_loader = DataLoader(dataset, batch_size=2) diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 3d6e079..8c2fa60 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -3,7 +3,6 @@ # Third-party import pytest -import torch # First-party from neural_lam.build_rectangular_graph import ( @@ -13,6 +12,7 @@ from neural_lam.datastore import DATASTORES from tests.conftest import ( DATASTORES_BOUNDARY_EXAMPLES, + check_saved_graph, get_test_mesh_dist, init_datastore_boundary_example, init_datastore_example, @@ -25,8 +25,8 @@ list(DATASTORES_BOUNDARY_EXAMPLES.keys()) + [None], ) @pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"]) -def test_graph_creation(datastore_name, datastore_boundary_name, archetype): - """Check that the `create_ graph_from_datastore` function is implemented. +def test_build_archetype(datastore_name, datastore_boundary_name, archetype): + """Check that the `build_graph_from_archetype` function is implemented. And that the graph is created in the correct location. """ @@ -50,86 +50,102 @@ def test_graph_creation(datastore_name, datastore_boundary_name, archetype): # Add additional multi-level kwargs create_kwargs.update( { - "level_refinement_factor": 3, + "level_refinement_factor": 2, "max_num_levels": num_levels, } ) - required_graph_files = [ - "m2m_edge_index.pt", - "g2m_edge_index.pt", - "m2g_edge_index.pt", - "m2m_features.pt", - "g2m_features.pt", - "m2g_features.pt", - "m2m_node_features.pt", - ] + # Name graph + graph_name = f"{datastore_name}_{datastore_boundary_name}_{archetype}" + + # Saved in datastore + # TODO: Maybe save in tmp dir? + graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name + + build_graph_from_archetype( + datastore, datastore_boundary, graph_name, archetype, **create_kwargs + ) hierarchical = archetype == "hierarchical" - if hierarchical: - required_graph_files.extend( - [ - "mesh_up_edge_index.pt", - "mesh_down_edge_index.pt", - "mesh_up_features.pt", - "mesh_down_features.pt", - ] + check_saved_graph(graph_dir_path, hierarchical) + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +@pytest.mark.parametrize( + "datastore_boundary_name", + list(DATASTORES_BOUNDARY_EXAMPLES.keys()) + [None], +) +@pytest.mark.parametrize( + "config_i, graph_kwargs", + enumerate( + [ + # Assortment of options + { + "m2m_connectivity": "flat", + "m2g_connectivity": "nearest_neighbour", + "g2m_connectivity": "nearest_neighbour", + "m2m_connectivity_kwargs": {}, + }, + { + "m2m_connectivity": "flat_multiscale", + "m2g_connectivity": "nearest_neighbours", + "g2m_connectivity": "within_radius", + "m2m_connectivity_kwargs": { + "level_refinement_factor": 2, + }, + "m2g_connectivity_kwargs": { + "max_num_neighbours": 4, + }, + "g2m_connectivity_kwargs": { + "rel_max_dist": 0.3, + }, + }, + { + "m2m_connectivity": "hierarchical", + "m2g_connectivity": "containing_rectangle", + "g2m_connectivity": "within_radius", + "m2m_connectivity_kwargs": { + "level_refinement_factor": 2, + }, + "m2g_connectivity_kwargs": {}, + "g2m_connectivity_kwargs": { + "rel_max_dist": 0.51, + }, + }, + ] + ), +) +def test_build_from_options( + datastore_name, datastore_boundary_name, config_i, graph_kwargs +): + """Check that the `build_graph_from_archetype` function is implemented. + And that the graph is created in the correct location. + + """ + datastore = init_datastore_example(datastore_name) + + if datastore_boundary_name is None: + # LAM scale + datastore_boundary = None + else: + # Global scale, ERA5 coords flattened with proj + datastore_boundary = init_datastore_boundary_example( + datastore_boundary_name ) - num_levels = 3 - # TODO: check that the number of edges is consistent over the files, for - # now we just check the number of features - d_features = 3 - d_mesh_static = 2 + # Insert mesh distance + graph_kwargs["m2m_connectivity_kwargs"][ + "mesh_node_distance" + ] = get_test_mesh_dist(datastore, datastore_boundary) # Name graph - graph_name = f"{datastore_name}_{datastore_boundary_name}_{archetype}" + graph_name = f"{datastore_name}_{datastore_boundary_name}_config{config_i}" # Saved in datastore # TODO: Maybe save in tmp dir? graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name - build_graph_from_archetype( - datastore, datastore_boundary, graph_name, archetype, **create_kwargs - ) + build_graph(datastore, datastore_boundary, graph_name, **graph_kwargs) - assert graph_dir_path.exists() - - # check that all the required files are present - for file_name in required_graph_files: - assert (graph_dir_path / file_name).exists() - - # try to load each and ensure they have the right shape - for file_name in required_graph_files: - file_id = Path(file_name).stem # remove the extension - result = torch.load(graph_dir_path / file_name) - - if file_id.startswith("g2m") or file_id.startswith("m2g"): - assert isinstance(result, torch.Tensor) - - if file_id.endswith("_index"): - assert result.shape[0] == 2 # adjacency matrix uses two rows - elif file_id.endswith("_features"): - assert result.shape[1] == d_features - - elif file_id.startswith("m2m") or file_id.startswith("mesh"): - assert isinstance(result, list) - if not hierarchical: - assert len(result) == 1 - else: - if file_id.startswith("mesh_up") or file_id.startswith( - "mesh_down" - ): - assert len(result) == num_levels - 1 - else: - assert len(result) == num_levels - - for r in result: - assert isinstance(r, torch.Tensor) - - if file_id == "m2m_node_features": - assert r.shape[1] == d_mesh_static - elif file_id.endswith("_index"): - assert r.shape[0] == 2 # adjacency matrix uses two rows - elif file_id.endswith("_features"): - assert r.shape[1] == d_features + hierarchical = graph_kwargs["m2m_connectivity"] == "hierarchical" + check_saved_graph(graph_dir_path, hierarchical) diff --git a/tests/test_training.py b/tests/test_training.py index 5f2d43d..4d07b08 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -62,16 +62,16 @@ def test_training(datastore_name, datastore_boundary_name): log_every_n_steps=1, ) - graph_name = "1level" + flat_graph_name = "1level" - graph_dir_path = Path(datastore.root_path) / "graphs" / graph_name + graph_dir_path = Path(datastore.root_path) / "graphs" / flat_graph_name def _create_graph(): if not graph_dir_path.exists(): build_graph_from_archetype( datastore=datastore, datastore_boundary=datastore_boundary, - graph_name=graph_name, + graph_name=flat_graph_name, archetype="keisler", mesh_node_distance=get_test_mesh_dist( datastore, datastore_boundary @@ -99,7 +99,7 @@ class ModelArgs: n_example_pred = 1 # XXX: this should be superfluous when we have already defined the # model object no? - graph = graph_name + graph_name = flat_graph_name hidden_dim = 4 hidden_layers = 1 processor_layers = 2 @@ -111,6 +111,7 @@ class ModelArgs: num_future_forcing_steps = 1 num_past_boundary_steps = 1 num_future_boundary_steps = 1 + shared_grid_embedder = False model_args = ModelArgs() @@ -120,10 +121,12 @@ class ModelArgs: ) ) - model = GraphLAM( # noqa + model = GraphLAM( args=model_args, datastore=datastore, + datastore_boundary=datastore_boundary, config=config, - ) + ) # noqa + wandb.init() trainer.fit(model=model, datamodule=data_module)