From dddc7a773525fa2e29d0ea17eeeef6361b007617 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 12 Jun 2024 17:42:10 +0200 Subject: [PATCH 1/8] Start implementing graph classes for graph models --- .gitignore | 2 +- neural_lam/graphs/base_weather_graph.py | 134 ++++++++++++++++++++++++ neural_lam/graphs/flat_weather_graph.py | 76 ++++++++++++++ 3 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 neural_lam/graphs/base_weather_graph.py create mode 100644 neural_lam/graphs/flat_weather_graph.py diff --git a/.gitignore b/.gitignore index 65e9f6f8..185d25e8 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ slurm_log* saved_models lightning_logs data -graphs +graphs/* *.sif sweeps test_*.sh diff --git a/neural_lam/graphs/base_weather_graph.py b/neural_lam/graphs/base_weather_graph.py new file mode 100644 index 00000000..afc92b71 --- /dev/null +++ b/neural_lam/graphs/base_weather_graph.py @@ -0,0 +1,134 @@ +# Standard library +import os + +# Third-party +import torch +import torch.nn as nn + + +class BaseWeatherGraph(nn.Module): + """ + Graph object representing weather graph consisting of grid and mesh nodes + """ + + def __init__( + self, + g2m_edge_index, + g2m_edge_features, + m2g_edge_index, + m2g_edge_features, + ): + """ + Create a new graph from tensors + """ + super().__init__() + + # Store edge indices + self.g2m_edge_index = g2m_edge_index + self.m2g_edge_index = m2g_edge_index + + # Store edge features + self.g2m_edge_features = g2m_edge_features + self.m2g_edge_features = m2g_edge_features + + # TODO Checks that node indices align + # TODO Make all node indices start at 0 + + def num_mesh_nodes(self): + # TODO use g2m + pass + + @staticmethod + def from_graph_dir(path): + """ + Create WeatherGraph from tensors stored in a graph directory + + path: str, path to directory where graph data is stored + """ + ( + g2m_edge_index, + g2m_edge_features, + ) = BaseWeatherGraph._load_subgraph_from_dir( + path, "g2m" + ) # (2, M_g2m), 
(M_g2m, d_edge_features) + ( + m2g_edge_index, + m2g_edge_features, + ) = BaseWeatherGraph._load_subgraph_from_dir( + path, "m2g" + ) # (2, M_m2g), (M_m2g, d_edge_features) + + return BaseWeatherGraph( + g2m_edge_index, + g2m_edge_features, + m2g_edge_index, + m2g_edge_features, + ) + + @staticmethod + def _load_subgraph_from_dir(graph_dir_path, subgraph_name): + """ + Load edge_index + feature tensor from a graph directory, + for a specific subgraph + """ + edge_index = BaseWeatherGraph._load_graph_tensor( + graph_dir_path, f"{subgraph_name}_edge_index.pt" + ) + + # Check edge_index + assert edge_index.dtype == torch.int64, ( + f"Wrong data type for {subgraph_name} edge_index " + f"tensor: {edge_index.dtype}" + ) + assert len(edge_index.shape) == 2, ( + f"Wrong shape of {subgraph_name} edge_index tensor: " + f"{edge_index.shape}" + ) + assert edge_index.shape[0] == 2, ( + "Wrong shape of {subgraph_name} edge_index tensor: " + f"{edge_index.shape}" + ) + + edge_features = BaseWeatherGraph._load_feature_tensor( + graph_dir_path, f"{subgraph_name}_features.pt" + ) + + # Check compatibility + assert edge_features.shape[0] == edge_index.shape[1], ( + f"Mismatch in shape of {subgraph_name} edge_index " + f"(edge_index.shape) and features {edge_features.shape}" + ) + + return edge_index, edge_features + + @staticmethod + def _load_feature_tensor(graph_dir_path, file_name): + """ + Load feature tensor with from a graph directory + """ + features = BaseWeatherGraph._load_graph_tensor( + graph_dir_path, file_name + ) + + # Check features + assert features.dtype == torch.float32, ( + f"Wrong data type for {file_name} graph feature tensor: " + f"{features.dtype}" + ) + assert len(features.shape) == 2, ( + f"Wrong shape of {file_name} graph feature tensor: " + f"{features.shape}" + ) + + return features + + @staticmethod + def _load_graph_tensor(graph_dir_path, file_name): + """ + Load graph tensor with edge_index or features from a graph directory + """ + return 
torch.load(os.path.join(graph_dir_path, file_name)) + + def __str__(): + # TODO Get from graph model init functions + pass diff --git a/neural_lam/graphs/flat_weather_graph.py b/neural_lam/graphs/flat_weather_graph.py new file mode 100644 index 00000000..4d6f43b6 --- /dev/null +++ b/neural_lam/graphs/flat_weather_graph.py @@ -0,0 +1,76 @@ +# First-party +from neural_lam.graphs.base_weather_graph import BaseWeatherGraph + + +class FlatWeatherGraph(BaseWeatherGraph): + """ + Graph object representing weather graph consisting of grid and mesh nodes + """ + + def __init__( + self, + g2m_edge_index, + g2m_edge_features, + m2g_edge_index, + m2g_edge_features, + m2m_edge_index, + m2m_edge_features, + mesh_node_features, + ): + """ + Create a new graph from tensors + """ + super().__init__( + g2m_edge_index, + g2m_edge_features, + m2g_edge_index, + m2g_edge_features, + ) + + # Store mesh tensors + self.m2m_edge_index = m2m_edge_index + self.m2m_edge_features = m2m_edge_features + self.mesh_node_features = mesh_node_features + + # TODO Checks that node indices align + + def num_mesh_nodes(self): + # TODO use mesh_node_features + pass + + @staticmethod + def from_graph_dir(path): + """ + Create WeatherGraph from tensors stored in a graph directory + + path: str, path to directory where graph data is stored + """ + # Load base grpah (g2m and m2g) + base_graph = BaseWeatherGraph.from_graph_dir(path) + + # Load m2m + ( + m2m_edge_index, + m2m_edge_features, + ) = BaseWeatherGraph._load_subgraph_from_dir( + path, "m2m" + ) # (2, M_m2m), (M_m2m, d_edge_features) + + # Load static mesh node features + mesh_node_features = BaseWeatherGraph._load_feature_tensor( + path, "mesh_features.pt" + ) # (N_mesh, d_node_features) + + # Note: We assume that graph features are already normalized + # when read from disk + # TODO ^ actually do this in graph creation + + return FlatWeatherGraph( + base_graph.g2m_edge_index, + base_graph.g2m_edge_features, + base_graph.m2g_edge_index, + 
base_graph.m2g_edge_features, + m2m_edge_index, + m2m_edge_features, + mesh_node_features, + ) From 258b0137e86268ee7addafcd57ec7c64c6436d8a Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 26 Jun 2024 11:17:45 +0200 Subject: [PATCH 2/8] Turn graph classes into python dataclasses --- neural_lam/graphs/base_weather_graph.py | 123 +++++++++++++++--------- neural_lam/graphs/flat_weather_graph.py | 39 +++----- 2 files changed, 91 insertions(+), 71 deletions(-) diff --git a/neural_lam/graphs/base_weather_graph.py b/neural_lam/graphs/base_weather_graph.py index afc92b71..66a5ad30 100644 --- a/neural_lam/graphs/base_weather_graph.py +++ b/neural_lam/graphs/base_weather_graph.py @@ -1,38 +1,97 @@ # Standard library import os +from dataclasses import dataclass # Third-party import torch import torch.nn as nn +@dataclass class BaseWeatherGraph(nn.Module): """ Graph object representing weather graph consisting of grid and mesh nodes """ - def __init__( - self, - g2m_edge_index, - g2m_edge_features, - m2g_edge_index, - m2g_edge_features, - ): + g2m_edge_index: torch.Tensor + g2m_edge_features: torch.Tensor + m2g_edge_index: torch.Tensor + m2g_edge_features: torch.Tensor + + def __post_init__(self): + BaseWeatherGraph.check_subgraph( + self.g2m_edge_features, self.g2m_edge_index, "g2m" + ) + BaseWeatherGraph.check_subgraph( + self.m2g_edge_features, self.m2g_edge_index, "m2g" + ) + + # TODO Checks that node indices align + # TODO Make all node indices start at 0 + + @staticmethod + def check_features(features, subgraph_name): """ - Create a new graph from tensors + Check that feature tensor has the correct format + + features: (2, num_features) tensor of features + subgraph_name: name of associated subgraph, used in error messages """ - super().__init__() + assert isinstance( + features, torch.Tensor + ), f"{subgraph_name} features is not a tensor" + assert features.dtype == torch.float32, ( + f"Wrong data type for {subgraph_name} feature tensor: " + f"{features.dtype}" 
+ ) + assert len(features.shape) == 2, ( + f"Wrong shape of {subgraph_name} feature tensor: " + f"{features.shape}" + ) - # Store edge indices - self.g2m_edge_index = g2m_edge_index - self.m2g_edge_index = m2g_edge_index + @staticmethod + def check_edge_index(edge_index, subgraph_name): + """ + Check that edge index tensor has the correct format - # Store edge features - self.g2m_edge_features = g2m_edge_features - self.m2g_edge_features = m2g_edge_features + edge_index: (2, num_edges) tensor with edge index + subgraph_name: name of associated subgraph, used in error messages + """ + assert isinstance( + edge_index, torch.Tensor + ), f"{subgraph_name} edge_index is not a tensor" + assert edge_index.dtype == torch.int64, ( + f"Wrong data type for {subgraph_name} edge_index " + f"tensor: {edge_index.dtype}" + ) + assert len(edge_index.shape) == 2, ( + f"Wrong shape of {subgraph_name} edge_index tensor: " + f"{edge_index.shape}" + ) + assert edge_index.shape[0] == 2, ( + "Wrong shape of {subgraph_name} edge_index tensor: " + f"{edge_index.shape}" + ) - # TODO Checks that node indices align - # TODO Make all node indices start at 0 + @staticmethod + def check_subgraph(edge_features, edge_index, subgraph_name): + """ + Check that tensors associated with subgraph (edge index and features) + has the correct format + + edge_features: (2, num_features) tensor of edge features + edge_index: (2, num_edges) tensor with edge index + subgraph_name: name of associated subgraph, used in error messages + """ + # Check individual tensors + BaseWeatherGraph.check_features(edge_features, subgraph_name) + BaseWeatherGraph.check_edge_index(edge_index, subgraph_name) + + # Check compatibility + assert edge_features.shape[0] == edge_index.shape[1], ( + f"Mismatch in shape of {subgraph_name} edge_index " + f"(edge_index.shape) and features {edge_features.shape}" + ) def num_mesh_nodes(self): # TODO use g2m @@ -75,30 +134,10 @@ def _load_subgraph_from_dir(graph_dir_path, subgraph_name): 
graph_dir_path, f"{subgraph_name}_edge_index.pt" ) - # Check edge_index - assert edge_index.dtype == torch.int64, ( - f"Wrong data type for {subgraph_name} edge_index " - f"tensor: {edge_index.dtype}" - ) - assert len(edge_index.shape) == 2, ( - f"Wrong shape of {subgraph_name} edge_index tensor: " - f"{edge_index.shape}" - ) - assert edge_index.shape[0] == 2, ( - "Wrong shape of {subgraph_name} edge_index tensor: " - f"{edge_index.shape}" - ) - edge_features = BaseWeatherGraph._load_feature_tensor( graph_dir_path, f"{subgraph_name}_features.pt" ) - # Check compatibility - assert edge_features.shape[0] == edge_index.shape[1], ( - f"Mismatch in shape of {subgraph_name} edge_index " - f"(edge_index.shape) and features {edge_features.shape}" - ) - return edge_index, edge_features @staticmethod @@ -110,16 +149,6 @@ def _load_feature_tensor(graph_dir_path, file_name): graph_dir_path, file_name ) - # Check features - assert features.dtype == torch.float32, ( - f"Wrong data type for {file_name} graph feature tensor: " - f"{features.dtype}" - ) - assert len(features.shape) == 2, ( - f"Wrong shape of {file_name} graph feature tensor: " - f"{features.shape}" - ) - return features @staticmethod diff --git a/neural_lam/graphs/flat_weather_graph.py b/neural_lam/graphs/flat_weather_graph.py index 4d6f43b6..933ec228 100644 --- a/neural_lam/graphs/flat_weather_graph.py +++ b/neural_lam/graphs/flat_weather_graph.py @@ -1,37 +1,28 @@ +# Standard library +from dataclasses import dataclass + +# Third-party +import torch + # First-party from neural_lam.graphs.base_weather_graph import BaseWeatherGraph +@dataclass class FlatWeatherGraph(BaseWeatherGraph): """ Graph object representing weather graph consisting of grid and mesh nodes """ - def __init__( - self, - g2m_edge_index, - g2m_edge_features, - m2g_edge_index, - m2g_edge_features, - m2m_edge_index, - m2m_edge_features, - mesh_node_features, - ): - """ - Create a new graph from tensors - """ - super().__init__( - g2m_edge_index, - 
g2m_edge_features, - m2g_edge_index, - m2g_edge_features, - ) - - # Store mesh tensors - self.m2m_edge_index = m2m_edge_index - self.m2m_edge_features = m2m_edge_features - self.mesh_node_features = mesh_node_features + m2m_edge_index: torch.Tensor + m2m_edge_features: torch.Tensor + mesh_node_features: torch.Tensor + def __post_init__(self): + BaseWeatherGraph.check_subgraph( + self.m2m_edge_features, self.m2m_edge_index, "m2m" + ) + BaseWeatherGraph.check_features(self.mesh_node_features, "mesh nodes") # TODO Checks that node indices align def num_mesh_nodes(self): From 183b24bbc66cf75d6a18e7eddcff8e410f61a252 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Wed, 26 Jun 2024 11:54:46 +0200 Subject: [PATCH 3/8] Add tests for flat graph class --- neural_lam/graphs/flat_weather_graph.py | 1 + tests/test_graph_classes.py | 100 ++++++++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 tests/test_graph_classes.py diff --git a/neural_lam/graphs/flat_weather_graph.py b/neural_lam/graphs/flat_weather_graph.py index 933ec228..623b199e 100644 --- a/neural_lam/graphs/flat_weather_graph.py +++ b/neural_lam/graphs/flat_weather_graph.py @@ -19,6 +19,7 @@ class FlatWeatherGraph(BaseWeatherGraph): mesh_node_features: torch.Tensor def __post_init__(self): + super().__post_init__() BaseWeatherGraph.check_subgraph( self.m2m_edge_features, self.m2m_edge_index, "m2m" ) diff --git a/tests/test_graph_classes.py b/tests/test_graph_classes.py new file mode 100644 index 00000000..bf2549d7 --- /dev/null +++ b/tests/test_graph_classes.py @@ -0,0 +1,100 @@ +# Standard library +import copy + +# Third-party +import pytest +import torch + +# First-party +from neural_lam.graphs.flat_weather_graph import FlatWeatherGraph + + +def create_dummy_graph_tensors(): + """ + Create dummy tensors for instantiating a flat graph + """ + num_grid = 10 + num_mesh = 5 + feature_dim = 3 + + return { + "g2m_edge_index": torch.zeros(2, num_grid, dtype=torch.long), + "g2m_edge_features": ( + 
torch.zeros(num_grid, feature_dim, dtype=torch.float32) + ), + "m2g_edge_index": torch.zeros(2, num_grid, dtype=torch.long), + "m2g_edge_features": ( + torch.zeros(num_grid, feature_dim, dtype=torch.float32) + ), + "m2m_edge_index": torch.zeros(2, num_mesh, dtype=torch.long), + "m2m_edge_features": ( + torch.zeros(num_mesh, feature_dim, dtype=torch.float32) + ), + "mesh_node_features": ( + torch.zeros(num_mesh, feature_dim, dtype=torch.float32) + ), + } + + +def test_create_flat_graph(): + """ + Test that a Flat weather graph can be created with correct tensors + """ + FlatWeatherGraph(**create_dummy_graph_tensors()) + + +@pytest.mark.parametrize( + "subgraph_name,tensor_name", + [ + (subgraph_name, tensor_name) + for subgraph_name in ("g2m", "m2g", "m2m") + for tensor_name in ("edge_features", "edge_index") + ] + + [("mesh", "node_features")], +) +def test_dtypes_flat_graph(subgraph_name, tensor_name): + """ + Test that wrong data types properly raises errors + """ + dummy_tensors = create_dummy_graph_tensors() + + # Test non-tensor input + dummy_copy = copy.copy(dummy_tensors) + dummy_copy[f"{subgraph_name}_{tensor_name}"] = 1 # Not a torch.Tensor + + with pytest.raises(AssertionError) as assertinfo: + FlatWeatherGraph(**dummy_copy) + assert subgraph_name in str( + assertinfo + ), "AssertionError did not contain {subgraph_name}" + + # Test wrong data type + dummy_copy = copy.copy(dummy_tensors) + tensor_key = f"{subgraph_name}_{tensor_name}" + dummy_copy[tensor_key] = dummy_copy[tensor_key].to(torch.float16) + + with pytest.raises(AssertionError) as assertinfo: + FlatWeatherGraph(**dummy_copy) + assert subgraph_name in str( + assertinfo + ), "AssertionError did not contain {subgraph_name}" + + +@pytest.mark.parametrize("subgraph_name", ["g2m", "m2g", "m2m"]) +def test_shape_match_flat_graph(subgraph_name): + """ + Test that shape mismatch between features and edge index + properly raises errors + """ + dummy_tensors = create_dummy_graph_tensors() + + tensor_key = 
f"{subgraph_name}_edge_features" + original_shape = dummy_tensors[tensor_key].shape + new_shape = (original_shape[0] + 1, original_shape[1]) + dummy_tensors[tensor_key] = torch.zeros(*new_shape, dtype=torch.float32) + + with pytest.raises(AssertionError) as assertinfo: + FlatWeatherGraph(**dummy_tensors) + assert subgraph_name in str( + assertinfo + ), "AssertionError did not contain {subgraph_name}" From 01193f5367cceaec47c2a15cb127aa5b1c1b51d0 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Tue, 16 Jul 2024 21:20:54 +0200 Subject: [PATCH 4/8] Implement node counts based on g2m indices --- neural_lam/graphs/base_weather_graph.py | 39 ++++++++++++-- neural_lam/graphs/flat_weather_graph.py | 10 ++-- tests/test_graph_classes.py | 67 +++++++++++++++++++------ 3 files changed, 92 insertions(+), 24 deletions(-) diff --git a/neural_lam/graphs/base_weather_graph.py b/neural_lam/graphs/base_weather_graph.py index 66a5ad30..1566be2a 100644 --- a/neural_lam/graphs/base_weather_graph.py +++ b/neural_lam/graphs/base_weather_graph.py @@ -4,7 +4,7 @@ # Third-party import torch -import torch.nn as nn +from torch import nn @dataclass @@ -26,8 +26,20 @@ def __post_init__(self): self.m2g_edge_features, self.m2g_edge_index, "m2g" ) - # TODO Checks that node indices align - # TODO Make all node indices start at 0 + # Make all node indices start at 0, if not + # TODO Remove from Inets + self.g2m_edge_index = self._reindex_edge_index(self.g2m_edge_index) + self.m2g_edge_index = self._reindex_edge_index(self.m2g_edge_index) + + @staticmethod + def _reindex_edge_index(edge_index): + """ + Create a version of edge_index with both sender and receiver indices + starting at 0. 
+ + edge_index: (2, num_edges) tensor with edge index + """ + return edge_index - edge_index.min(dim=1, keepdim=True)[0] @staticmethod def check_features(features, subgraph_name): @@ -93,9 +105,26 @@ def check_subgraph(edge_features, edge_index, subgraph_name): f"(edge_index.shape) and features {edge_features.shape}" ) + # TODO Checks that node indices align between edge_index and features + + # TODO Cache? + # @functools.cached_property + @property + def num_grid_nodes(self): + """ + Get the number of grid nodes (grid cells) that the graph + is constructed for. + """ + return self.g2m_edge_index[0].max().item() + 1 + + # TODO Cache? + # @functools.cached_property + @property def num_mesh_nodes(self): - # TODO use g2m - pass + """ + Get the number of nodes in the mesh graph + """ + return self.g2m_edge_index[1].max().item() + 1 @staticmethod def from_graph_dir(path): diff --git a/neural_lam/graphs/flat_weather_graph.py b/neural_lam/graphs/flat_weather_graph.py index 623b199e..2f1c8ee7 100644 --- a/neural_lam/graphs/flat_weather_graph.py +++ b/neural_lam/graphs/flat_weather_graph.py @@ -3,6 +3,7 @@ # Third-party import torch +import torch_geometric as pyg # First-party from neural_lam.graphs.base_weather_graph import BaseWeatherGraph @@ -24,11 +25,12 @@ def __post_init__(self): self.m2m_edge_features, self.m2m_edge_index, "m2m" ) BaseWeatherGraph.check_features(self.mesh_node_features, "mesh nodes") - # TODO Checks that node indices align - def num_mesh_nodes(self): - # TODO use mesh_node_features - pass + # Check that m2m has correct properties + assert not pyg.utils.contains_isolated_nodes( + self.m2m_edge_index + ), "m2m_edge_index does not specify a connected graph" + # TODO Checks that node indices align with number of mesh nodes @staticmethod def from_graph_dir(path): diff --git a/tests/test_graph_classes.py b/tests/test_graph_classes.py index bf2549d7..744bfe7e 100644 --- a/tests/test_graph_classes.py +++ b/tests/test_graph_classes.py @@ -4,34 +4,61 @@ # 
Third-party import pytest import torch +import torch_geometric as pyg # First-party from neural_lam.graphs.flat_weather_graph import FlatWeatherGraph +NUM_GRID = 10 +NUM_MESH = 5 +FEATURE_DIM = 3 + def create_dummy_graph_tensors(): """ - Create dummy tensors for instantiating a flat graph + Create dummy tensors for instantiating a flat graph. + In the dummy tensors all grid nodes connect to all mesh nodes in g2m and m2g + (complete bipartite graph). + m2m is a complete graph. """ - num_grid = 10 - num_mesh = 5 - feature_dim = 3 - return { - "g2m_edge_index": torch.zeros(2, num_grid, dtype=torch.long), + "g2m_edge_index": torch.stack( + ( + torch.repeat_interleave(torch.arange(NUM_GRID), NUM_MESH), + torch.arange(NUM_MESH).repeat(NUM_GRID), + ), + dim=0, + ), "g2m_edge_features": ( - torch.zeros(num_grid, feature_dim, dtype=torch.float32) + torch.zeros(NUM_GRID * NUM_MESH, FEATURE_DIM, dtype=torch.float32) + ), + "m2g_edge_index": torch.stack( + ( + torch.arange(NUM_MESH).repeat(NUM_GRID), + torch.repeat_interleave(torch.arange(NUM_GRID), NUM_MESH), + ), + dim=0, ), - "m2g_edge_index": torch.zeros(2, num_grid, dtype=torch.long), "m2g_edge_features": ( - torch.zeros(num_grid, feature_dim, dtype=torch.float32) + torch.zeros(NUM_GRID * NUM_MESH, FEATURE_DIM, dtype=torch.float32) ), - "m2m_edge_index": torch.zeros(2, num_mesh, dtype=torch.long), + "m2m_edge_index": pyg.utils.remove_self_loops( + torch.stack( + ( + torch.repeat_interleave(torch.arange(NUM_MESH), NUM_MESH), + torch.arange(NUM_MESH).repeat(NUM_MESH), + ), + dim=0, + ) + )[0], "m2m_edge_features": ( - torch.zeros(num_mesh, feature_dim, dtype=torch.float32) + # Number of edges in complete graph of N nodes is N(N-1) + torch.zeros( + NUM_MESH * (NUM_MESH - 1), FEATURE_DIM, dtype=torch.float32 + ) ), "mesh_node_features": ( - torch.zeros(num_mesh, feature_dim, dtype=torch.float32) + torch.zeros(NUM_MESH, FEATURE_DIM, dtype=torch.float32) ), } @@ -40,7 +67,17 @@ def test_create_flat_graph(): """ Test that a Flat 
weather graph can be created with correct tensors """ - FlatWeatherGraph(**create_dummy_graph_tensors()) + graph = FlatWeatherGraph(**create_dummy_graph_tensors()) + + # Check that node counts are correct + assert graph.num_grid_nodes == NUM_GRID, ( + "num_grid_nodes returns wrong number of grid nodes: " + f"{graph.num_grid_nodes} (true number is {NUM_GRID})" + ) + assert graph.num_mesh_nodes == NUM_MESH, ( + "num_mesh_nodes returns wrong number of mesh nodes: " + f"{graph.num_mesh_nodes} (true number is {NUM_MESH})" + ) @pytest.mark.parametrize( @@ -66,7 +103,7 @@ def test_dtypes_flat_graph(subgraph_name, tensor_name): FlatWeatherGraph(**dummy_copy) assert subgraph_name in str( assertinfo - ), "AssertionError did not contain {subgraph_name}" + ), f"AssertionError did not contain {subgraph_name}" # Test wrong data type dummy_copy = copy.copy(dummy_tensors) @@ -77,7 +114,7 @@ def test_dtypes_flat_graph(subgraph_name, tensor_name): FlatWeatherGraph(**dummy_copy) assert subgraph_name in str( assertinfo - ), "AssertionError did not contain {subgraph_name}" + ), f"AssertionError did not contain {subgraph_name}" @pytest.mark.parametrize("subgraph_name", ["g2m", "m2g", "m2m"]) From beabddbdf251079e49780412c395bb03c2b1cf10 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Thu, 25 Jul 2024 16:46:50 +0200 Subject: [PATCH 5/8] Freeze dataclasses and cache node count properties --- neural_lam/graphs/base_weather_graph.py | 30 ++++++++++++++++--------- neural_lam/graphs/flat_weather_graph.py | 12 +++++++++- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/neural_lam/graphs/base_weather_graph.py b/neural_lam/graphs/base_weather_graph.py index 1566be2a..935c2996 100644 --- a/neural_lam/graphs/base_weather_graph.py +++ b/neural_lam/graphs/base_weather_graph.py @@ -1,4 +1,5 @@ # Standard library +import functools import os from dataclasses import dataclass @@ -7,7 +8,7 @@ from torch import nn -@dataclass +@dataclass(frozen=True) class BaseWeatherGraph(nn.Module): 
""" Graph object representing weather graph consisting of grid and mesh nodes @@ -27,9 +28,20 @@ def __post_init__(self): ) # Make all node indices start at 0, if not - # TODO Remove from Inets - self.g2m_edge_index = self._reindex_edge_index(self.g2m_edge_index) - self.m2g_edge_index = self._reindex_edge_index(self.m2g_edge_index) + # Use of setattr hereduring initialization, as dataclass is frozen. + # This matches dataclass behavior used in generated __init__ + # https://docs.python.org/3/library/dataclasses.html#frozen-instances + # TODO Remove reindexing from from Inets + object.__setattr__( + self, + "g2m_edge_index", + self._reindex_edge_index(self.g2m_edge_index), + ) + object.__setattr__( + self, + "m2g_edge_index", + self._reindex_edge_index(self.m2g_edge_index), + ) @staticmethod def _reindex_edge_index(edge_index): @@ -107,23 +119,21 @@ def check_subgraph(edge_features, edge_index, subgraph_name): # TODO Checks that node indices align between edge_index and features - # TODO Cache? - # @functools.cached_property - @property + @functools.cached_property def num_grid_nodes(self): """ Get the number of grid nodes (grid cells) that the graph is constructed for. """ + # Assumes all grid nodes connected to grid return self.g2m_edge_index[0].max().item() + 1 - # TODO Cache? 
- # @functools.cached_property - @property + @functools.cached_property def num_mesh_nodes(self): """ Get the number of nodes in the mesh graph """ + # Assumes all mesh nodes connected to grid return self.g2m_edge_index[1].max().item() + 1 @staticmethod diff --git a/neural_lam/graphs/flat_weather_graph.py b/neural_lam/graphs/flat_weather_graph.py index 2f1c8ee7..b56a0c90 100644 --- a/neural_lam/graphs/flat_weather_graph.py +++ b/neural_lam/graphs/flat_weather_graph.py @@ -1,4 +1,5 @@ # Standard library +import functools from dataclasses import dataclass # Third-party @@ -9,7 +10,7 @@ from neural_lam.graphs.base_weather_graph import BaseWeatherGraph -@dataclass +@dataclass(frozen=True) class FlatWeatherGraph(BaseWeatherGraph): """ Graph object representing weather graph consisting of grid and mesh nodes @@ -32,6 +33,15 @@ def __post_init__(self): ), "m2m_edge_index does not specify a connected graph" # TODO Checks that node indices align with number of mesh nodes + @functools.cached_property + def num_mesh_nodes(self): + """ + Get the number of nodes in the mesh graph + """ + # Override to determine in more robust way + # No longer assumes all mesh nodes connected to grid + return self.mesh_node_features.shape[0] + @staticmethod def from_graph_dir(path): """ From 205fc1047e8a8b64995d5dbda68e0499f18c9304 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Sun, 28 Jul 2024 12:28:32 +0200 Subject: [PATCH 6/8] Add string representations of graph classes --- neural_lam/graphs/base_weather_graph.py | 32 +++++++++++++++++++++---- neural_lam/graphs/flat_weather_graph.py | 15 ++++++++++++ tests/test_graph_classes.py | 22 +++++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/neural_lam/graphs/base_weather_graph.py b/neural_lam/graphs/base_weather_graph.py index 935c2996..88f2f849 100644 --- a/neural_lam/graphs/base_weather_graph.py +++ b/neural_lam/graphs/base_weather_graph.py @@ -117,8 +117,6 @@ def check_subgraph(edge_features, edge_index, 
subgraph_name): f"(edge_index.shape) and features {edge_features.shape}" ) - # TODO Checks that node indices align between edge_index and features - @functools.cached_property def num_grid_nodes(self): """ @@ -136,6 +134,20 @@ def num_mesh_nodes(self): # Assumes all mesh nodes connected to grid return self.g2m_edge_index[1].max().item() + 1 + @functools.cached_property + def num_m2g_edges(self): + """ + Get the number of edges in the grid-to-mesh graph + """ + return self.g2m_edge_index.shape[1] + + @functools.cached_property + def num_g2m_edges(self): + """ + Get the number of edges in the mesh-to-grid graph + """ + return self.m2g_edge_index.shape[1] + @staticmethod def from_graph_dir(path): """ @@ -197,6 +209,16 @@ def _load_graph_tensor(graph_dir_path, file_name): """ return torch.load(os.path.join(graph_dir_path, file_name)) - def __str__(): - # TODO Get from graph model init functions - pass + def __str__(self): + """ + Returns a string representation of the graph, including the total + number of nodes and the breakdown of nodes and edges in subgraphs. 
+ """ + total_nodes = self.num_grid_nodes + self.num_mesh_nodes + return ( + f"Graph with {total_nodes} nodes ({self.num_grid_nodes} grid, " + f"{self.num_mesh_nodes} mesh)\n" + f"Subgraphs:\n" + f"g2m with {self.num_g2m_edges} edges\n" + f"m2g with {self.num_m2g_edges} edges" + ) diff --git a/neural_lam/graphs/flat_weather_graph.py b/neural_lam/graphs/flat_weather_graph.py index b56a0c90..3da5d272 100644 --- a/neural_lam/graphs/flat_weather_graph.py +++ b/neural_lam/graphs/flat_weather_graph.py @@ -42,6 +42,13 @@ def num_mesh_nodes(self): # No longer assumes all mesh nodes connected to grid return self.mesh_node_features.shape[0] + @functools.cached_property + def num_m2m_edges(self): + """ + Get the number of edges in the mesh graph + """ + return self.m2m_edge_index.shape[1] + @staticmethod def from_graph_dir(path): """ @@ -78,3 +85,11 @@ def from_graph_dir(path): m2m_edge_features, mesh_node_features, ) + + def __str__(self): + """ + Returns a string representation of the graph, including the total + number of nodes and the breakdown of nodes and edges in subgraphs. 
+ """ + return super().__str__() + f"\nm2m with {self.num_m2m_edges} edges" + diff --git a/tests/test_graph_classes.py b/tests/test_graph_classes.py index 744bfe7e..f7a9cb0a 100644 --- a/tests/test_graph_classes.py +++ b/tests/test_graph_classes.py @@ -135,3 +135,25 @@ def test_shape_match_flat_graph(subgraph_name): assert subgraph_name in str( assertinfo ), "AssertionError did not contain {subgraph_name}" + + +def test_create_graph_str_rep(): + """ + Test that string representation of graph is correct + """ + graph = FlatWeatherGraph(**create_dummy_graph_tensors()) + str_rep = str(graph) + # Simple test that all relevant numbers are present + assert ( + str(NUM_GRID) in str_rep + ), "Correct number of grid nodes not in string representation of graph" + assert ( + str(NUM_MESH) in str_rep + ), "Correct number of mesh nodes not in string representation of graph" + + assert ( + str(NUM_MESH * NUM_GRID) in str_rep + ), "Correct number of g2m/m2g edges not in string representation of graph" + assert ( + str(NUM_MESH * (NUM_MESH - 1)) in str_rep + ), "Correct number of m2m edges not in string representation of graph" From 09c40ed4267013a63a6296a8c67246cee1ea1c27 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Sun, 28 Jul 2024 15:18:53 +0200 Subject: [PATCH 7/8] Add mesh node number consistency checks --- neural_lam/graphs/flat_weather_graph.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/neural_lam/graphs/flat_weather_graph.py b/neural_lam/graphs/flat_weather_graph.py index 3da5d272..df2a0ebb 100644 --- a/neural_lam/graphs/flat_weather_graph.py +++ b/neural_lam/graphs/flat_weather_graph.py @@ -31,7 +31,22 @@ def __post_init__(self): assert not pyg.utils.contains_isolated_nodes( self.m2m_edge_index ), "m2m_edge_index does not specify a connected graph" - # TODO Checks that node indices align with number of mesh nodes + + # Check that number of mesh nodes is consistent in node and edge sets + g2m_num = 
@dataclass(frozen=True)
class HierarchicalWeatherGraph(BaseWeatherGraph):
    """
    Graph object representing weather graph consisting of grid and hierarchy of
    mesh nodes.

    In addition to the grid<->mesh subgraphs handled by BaseWeatherGraph, the
    mesh hierarchy is described by the list-valued fields below. Each of them
    is replaced by a non-persistent BufferList in __post_init__.
    """

    # Edge indices, one list entry per level (intra) or per connection between
    # levels (up/down) -- presumably tensors of shape (2, M); TODO confirm
    mesh_up_edge_index: list
    mesh_down_edge_index: list
    mesh_intra_edge_index: list
    # Edge features, entry i belongs to entry i of the matching edge_index list
    mesh_up_features: list
    mesh_down_features: list
    mesh_intra_features: list
    # Static node features, one list entry per mesh level
    mesh_level_node_features: list

    def __post_init__(self):
        """
        Validate the hierarchical components and store them in BufferLists

        Raises AssertionError if an edge_index list and its feature list have
        different lengths, or if a base-class check rejects an entry.
        """
        super().__post_init__()

        # Validate while the components are still plain sequences, so that
        # nothing here relies on how BufferList exposes its content
        for subgraph_name in ("mesh_up", "mesh_down", "mesh_intra"):
            edge_index_levels = getattr(self, f"{subgraph_name}_edge_index")
            edge_feature_levels = getattr(self, f"{subgraph_name}_features")
            assert len(edge_index_levels) == len(edge_feature_levels), (
                f"Different number of levels in {subgraph_name} edge_index "
                f"({len(edge_index_levels)}) and features "
                f"({len(edge_feature_levels)})"
            )
            for level, (edge_index, edge_features) in enumerate(
                zip(edge_index_levels, edge_feature_levels)
            ):
                BaseWeatherGraph.check_subgraph(
                    edge_features,
                    edge_index,
                    f"{subgraph_name} level {level}",
                )

        for level, node_features in enumerate(self.mesh_level_node_features):
            BaseWeatherGraph.check_features(
                node_features, f"mesh nodes level {level}"
            )

        # TODO Port the index-based checks of the flat graph (isolated nodes,
        # agreement between node counts and edge indices). They are left out
        # because they require knowing how nodes are indexed on each level.

        # Fill the cached counts from the plain lists, so the properties do
        # not need to iterate over a BufferList later on
        _ = (self.num_mesh_nodes, self.num_m2m_edges)

        # Put all hierarchical components in BufferList. The dataclass is
        # frozen, so attributes have to be set through object.__setattr__
        for component_name in (
            "mesh_up_edge_index",
            "mesh_down_edge_index",
            "mesh_intra_edge_index",
            "mesh_up_features",
            "mesh_down_features",
            "mesh_intra_features",
            "mesh_level_node_features",
        ):
            object.__setattr__(
                self,
                component_name,
                BufferList(getattr(self, component_name), persistent=False),
            )

    @functools.cached_property
    def num_mesh_nodes(self):
        """
        Get the number of nodes in the mesh graph, summed over all levels
        """
        # TODO What does this mean for a hierarchical graph?
        # NOTE(review): the total over all levels is an assumption -- confirm
        # that this is what users of num_mesh_nodes expect
        return sum(
            node_features.shape[0]
            for node_features in self.mesh_level_node_features
        )

    @functools.cached_property
    def num_m2m_edges(self):
        """
        Get the number of edges in the mesh graph, summed over all intra-level,
        up and down edge sets
        """
        # TODO What does this mean for a hierarchical graph?
        # NOTE(review): counting up/down edges as mesh edges is an assumption
        return sum(
            edge_index.shape[1]
            for edge_index_levels in (
                self.mesh_intra_edge_index,
                self.mesh_up_edge_index,
                self.mesh_down_edge_index,
            )
            for edge_index in edge_index_levels
        )

    @staticmethod
    def from_graph_dir(path):
        """
        Create HierarchicalWeatherGraph from tensors stored in a graph
        directory

        path: str, path to directory where graph data is stored
        """
        # NOTE(review): the mesh_* components below are expected to be
        # per-level lists -- confirm that the base-class loaders return lists
        # for these files

        # Load base graph (g2m and m2g)
        base_graph = BaseWeatherGraph.from_graph_dir(path)

        # Load intra-level mesh edges
        (
            mesh_intra_edge_index,
            mesh_intra_edge_features,
        ) = BaseWeatherGraph._load_subgraph_from_dir(
            path, "mesh_intra"
        )  # List of (2, M_intra[l]), list of (M_intra[l], d_edge_features)

        # Load static mesh node features
        mesh_level_node_features = BaseWeatherGraph._load_feature_tensor(
            path, "mesh_features.pt"
        )  # List of (N_mesh, d_node_features)

        # Load up and down edges and features
        (
            mesh_up_edge_index,
            mesh_up_edge_features,
        ) = BaseWeatherGraph._load_subgraph_from_dir(
            path, "mesh_up"
        )  # List of (2, M_up[l]), list of (M_up[l], d_edge_features)
        (
            mesh_down_edge_index,
            mesh_down_edge_features,
        ) = BaseWeatherGraph._load_subgraph_from_dir(
            path, "mesh_down"
        )  # List of (2, M_down[l]), list of (M_down[l], d_edge_features)

        # Base-graph tensors are passed by position, the components declared
        # in this class by name so they cannot be mixed up
        return HierarchicalWeatherGraph(
            base_graph.g2m_edge_index,
            base_graph.g2m_edge_features,
            base_graph.m2g_edge_index,
            base_graph.m2g_edge_features,
            mesh_up_edge_index=mesh_up_edge_index,
            mesh_down_edge_index=mesh_down_edge_index,
            mesh_intra_edge_index=mesh_intra_edge_index,
            mesh_up_features=mesh_up_edge_features,
            mesh_down_features=mesh_down_edge_features,
            mesh_intra_features=mesh_intra_edge_features,
            mesh_level_node_features=mesh_level_node_features,
        )

    def __str__(self):
        """
        Returns the string representation of the base graph followed by a line
        with the number of mesh edges.
        """
        # TODO Break the edge count down per level and per edge set
        return super().__str__() + f"\nm2m with {self.num_m2m_edges} edges"