Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce graph classes for graph-based models #65

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ slurm_log*
saved_models
lightning_logs
data
graphs
graphs/*
*.sif
sweeps
test_*.sh
Expand Down
224 changes: 224 additions & 0 deletions neural_lam/graphs/base_weather_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# Standard library
import functools
import os
from dataclasses import dataclass

# Third-party
import torch
from torch import nn


@dataclass(frozen=True)
class BaseWeatherGraph(nn.Module):
    """
    Graph object representing weather graph consisting of grid and mesh nodes

    Holds the two subgraphs connecting grid and mesh nodes: g2m
    (grid-to-mesh) and m2g (mesh-to-grid). Each subgraph is described by an
    edge index tensor, where row 0 holds sender and row 1 receiver indices,
    and an edge feature tensor with one row per edge.

    All tensors are validated on construction and the edge indices are
    shifted so that both sender and receiver indices start at 0.

    g2m_edge_index: (2, M_g2m) int64 tensor with edge index
    g2m_edge_features: (M_g2m, d_edge_features) float32 tensor
    m2g_edge_index: (2, M_m2g) int64 tensor with edge index
    m2g_edge_features: (M_m2g, d_edge_features) float32 tensor
    """

    # NOTE(review): the dataclass-generated __init__ never calls
    # nn.Module.__init__, so nn.Module functionality relying on its internal
    # state (e.g. submodule registration, .to()) is presumably unavailable
    # on instances -- confirm that inheriting from nn.Module is intended.

    g2m_edge_index: torch.Tensor
    g2m_edge_features: torch.Tensor
    m2g_edge_index: torch.Tensor
    m2g_edge_features: torch.Tensor

    def __post_init__(self):
        """
        Validate the subgraph tensors and re-index the edge indices.

        Raises AssertionError if any tensor has the wrong type, dtype or
        shape.
        """
        BaseWeatherGraph.check_subgraph(
            self.g2m_edge_features, self.g2m_edge_index, "g2m"
        )
        BaseWeatherGraph.check_subgraph(
            self.m2g_edge_features, self.m2g_edge_index, "m2g"
        )

        # Make all node indices start at 0, if not
        # Use of setattr here during initialization, as dataclass is frozen.
        # This matches dataclass behavior used in generated __init__
        # https://docs.python.org/3/library/dataclasses.html#frozen-instances
        # TODO Remove reindexing from Inets
        object.__setattr__(
            self,
            "g2m_edge_index",
            self._reindex_edge_index(self.g2m_edge_index),
        )
        object.__setattr__(
            self,
            "m2g_edge_index",
            self._reindex_edge_index(self.m2g_edge_index),
        )

    @staticmethod
    def _reindex_edge_index(edge_index):
        """
        Create a version of edge_index with both sender and receiver indices
        starting at 0.

        edge_index: (2, num_edges) tensor with edge index

        Returns: (2, num_edges) tensor, where the minimum of each row has
        been subtracted from that row
        """
        # Row-wise minimum, kept as shape (2, 1) to broadcast over edges
        return edge_index - edge_index.min(dim=1, keepdim=True)[0]

    @staticmethod
    def check_features(features, subgraph_name):
        """
        Check that feature tensor has the correct format

        features: (num_entities, num_features) float32 tensor of features
        subgraph_name: name of associated subgraph, used in error messages

        Raises AssertionError if the format is wrong.
        """
        assert isinstance(
            features, torch.Tensor
        ), f"{subgraph_name} features is not a tensor"
        assert features.dtype == torch.float32, (
            f"Wrong data type for {subgraph_name} feature tensor: "
            f"{features.dtype}"
        )
        assert len(features.shape) == 2, (
            f"Wrong shape of {subgraph_name} feature tensor: "
            f"{features.shape}"
        )

    @staticmethod
    def check_edge_index(edge_index, subgraph_name):
        """
        Check that edge index tensor has the correct format

        edge_index: (2, num_edges) int64 tensor with edge index
        subgraph_name: name of associated subgraph, used in error messages

        Raises AssertionError if the format is wrong.
        """
        assert isinstance(
            edge_index, torch.Tensor
        ), f"{subgraph_name} edge_index is not a tensor"
        assert edge_index.dtype == torch.int64, (
            f"Wrong data type for {subgraph_name} edge_index "
            f"tensor: {edge_index.dtype}"
        )
        assert len(edge_index.shape) == 2, (
            f"Wrong shape of {subgraph_name} edge_index tensor: "
            f"{edge_index.shape}"
        )
        assert edge_index.shape[0] == 2, (
            f"Wrong shape of {subgraph_name} edge_index tensor: "
            f"{edge_index.shape}"
        )

    @staticmethod
    def check_subgraph(edge_features, edge_index, subgraph_name):
        """
        Check that tensors associated with subgraph (edge index and features)
        have the correct format

        edge_features: (num_edges, num_features) tensor of edge features
        edge_index: (2, num_edges) tensor with edge index
        subgraph_name: name of associated subgraph, used in error messages

        Raises AssertionError if either tensor has the wrong format, or if
        the number of edges differs between the two.
        """
        # Check individual tensors
        BaseWeatherGraph.check_features(edge_features, subgraph_name)
        BaseWeatherGraph.check_edge_index(edge_index, subgraph_name)

        # Check compatibility: one row of features per edge
        assert edge_features.shape[0] == edge_index.shape[1], (
            f"Mismatch in shape of {subgraph_name} edge_index "
            f"({edge_index.shape}) and features {edge_features.shape}"
        )

    @functools.cached_property
    def num_grid_nodes(self):
        """
        Get the number of grid nodes (grid cells) that the graph
        is constructed for.
        """
        # Grid nodes are the senders in g2m
        # Assumes all grid nodes connected to mesh
        return self.g2m_edge_index[0].max().item() + 1

    @functools.cached_property
    def num_mesh_nodes(self):
        """
        Get the number of nodes in the mesh graph
        """
        # Mesh nodes are the receivers in g2m
        # Assumes all mesh nodes connected to grid
        return self.g2m_edge_index[1].max().item() + 1

    @functools.cached_property
    def num_m2g_edges(self):
        """
        Get the number of edges in the mesh-to-grid graph
        """
        return self.m2g_edge_index.shape[1]

    @functools.cached_property
    def num_g2m_edges(self):
        """
        Get the number of edges in the grid-to-mesh graph
        """
        return self.g2m_edge_index.shape[1]

    @staticmethod
    def from_graph_dir(path):
        """
        Create WeatherGraph from tensors stored in a graph directory

        path: str, path to directory where graph data is stored

        Returns: BaseWeatherGraph built from the g2m and m2g tensors in path
        """
        (
            g2m_edge_index,
            g2m_edge_features,
        ) = BaseWeatherGraph._load_subgraph_from_dir(
            path, "g2m"
        )  # (2, M_g2m), (M_g2m, d_edge_features)
        (
            m2g_edge_index,
            m2g_edge_features,
        ) = BaseWeatherGraph._load_subgraph_from_dir(
            path, "m2g"
        )  # (2, M_m2g), (M_m2g, d_edge_features)

        return BaseWeatherGraph(
            g2m_edge_index,
            g2m_edge_features,
            m2g_edge_index,
            m2g_edge_features,
        )

    @staticmethod
    def _load_subgraph_from_dir(graph_dir_path, subgraph_name):
        """
        Load edge_index + feature tensor from a graph directory,
        for a specific subgraph

        graph_dir_path: str, path to directory where graph data is stored
        subgraph_name: name of subgraph, used as prefix of the file names
            "{subgraph_name}_edge_index.pt" and "{subgraph_name}_features.pt"

        Returns: tuple (edge_index, edge_features)
        """
        edge_index = BaseWeatherGraph._load_graph_tensor(
            graph_dir_path, f"{subgraph_name}_edge_index.pt"
        )

        edge_features = BaseWeatherGraph._load_feature_tensor(
            graph_dir_path, f"{subgraph_name}_features.pt"
        )

        return edge_index, edge_features

    @staticmethod
    def _load_feature_tensor(graph_dir_path, file_name):
        """
        Load feature tensor from a graph directory

        graph_dir_path: str, path to directory where graph data is stored
        file_name: name of file in the directory to load
        """
        features = BaseWeatherGraph._load_graph_tensor(
            graph_dir_path, file_name
        )

        return features

    @staticmethod
    def _load_graph_tensor(graph_dir_path, file_name):
        """
        Load graph tensor with edge_index or features from a graph directory

        graph_dir_path: str, path to directory where graph data is stored
        file_name: name of file in the directory to load
        """
        # NOTE(review): torch.load unpickles the file content, so graph
        # directories should only be loaded from trusted sources.
        return torch.load(os.path.join(graph_dir_path, file_name))

    def __str__(self):
        """
        Returns a string representation of the graph, including the total
        number of nodes and the breakdown of nodes and edges in subgraphs.
        """
        total_nodes = self.num_grid_nodes + self.num_mesh_nodes
        return (
            f"Graph with {total_nodes} nodes ({self.num_grid_nodes} grid, "
            f"{self.num_mesh_nodes} mesh)\n"
            f"Subgraphs:\n"
            f"g2m with {self.num_g2m_edges} edges\n"
            f"m2g with {self.num_m2g_edges} edges"
        )
109 changes: 109 additions & 0 deletions neural_lam/graphs/flat_weather_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Standard library
import functools
from dataclasses import dataclass

# Third-party
import torch
import torch_geometric as pyg

# First-party
from neural_lam.graphs.base_weather_graph import BaseWeatherGraph


@dataclass(frozen=True)
class FlatWeatherGraph(BaseWeatherGraph):
    """
    Weather graph with a single mesh-to-mesh (m2m) subgraph

    Extends BaseWeatherGraph (g2m and m2g subgraphs) with edges between
    mesh nodes and static features for the mesh nodes.

    m2m_edge_index: (2, M_m2m) int64 tensor with edge index
    m2m_edge_features: (M_m2m, d_edge_features) float32 tensor
    mesh_node_features: (N_mesh, d_node_features) float32 tensor
    """

    m2m_edge_index: torch.Tensor
    m2m_edge_features: torch.Tensor
    mesh_node_features: torch.Tensor

    def __post_init__(self):
        """
        Validate all tensors of the graph.

        Runs the checks of BaseWeatherGraph, then checks the format of the
        m2m and mesh node tensors and that all subgraphs agree on the number
        of mesh nodes.

        Raises AssertionError if any check fails.
        """
        super().__post_init__()
        BaseWeatherGraph.check_subgraph(
            self.m2m_edge_features, self.m2m_edge_index, "m2m"
        )
        BaseWeatherGraph.check_features(self.mesh_node_features, "mesh nodes")

        # Check that m2m has correct properties
        # Note: this only detects isolated nodes, it does not verify that
        # the mesh graph is connected
        assert not pyg.utils.contains_isolated_nodes(
            self.m2m_edge_index
        ), "m2m_edge_index contains isolated nodes"

        # Check that number of mesh nodes is consistent in node and edge sets
        # Mesh nodes are receivers in g2m and senders in m2g
        g2m_num = self.g2m_edge_index[1].max().item() + 1
        m2g_num = self.m2g_edge_index[0].max().item() + 1
        m2m_num_from = self.m2m_edge_index[0].max().item() + 1
        m2m_num_to = self.m2m_edge_index[1].max().item() + 1
        for edge_num_mesh, edge_description in (
            (g2m_num, "g2m edges"),
            (m2g_num, "m2g edges"),
            (m2m_num_from, "m2m edges (senders)"),
            (m2m_num_to, "m2m edges (receivers)"),
        ):
            assert edge_num_mesh == self.num_mesh_nodes, (
                "Different number of mesh nodes represented by: "
                f"{edge_description} and mesh node features"
            )

    @functools.cached_property
    def num_mesh_nodes(self):
        """
        Get the number of nodes in the mesh graph
        """
        # Override to determine in more robust way
        # No longer assumes all mesh nodes connected to grid
        return self.mesh_node_features.shape[0]

    @functools.cached_property
    def num_m2m_edges(self):
        """
        Get the number of edges in the mesh graph
        """
        return self.m2m_edge_index.shape[1]

    @staticmethod
    def from_graph_dir(path):
        """
        Create WeatherGraph from tensors stored in a graph directory

        path: str, path to directory where graph data is stored

        Returns: FlatWeatherGraph built from the g2m, m2g and m2m tensors
        and the mesh node features ("mesh_features.pt") in path
        """
        # Load base graph (g2m and m2g)
        base_graph = BaseWeatherGraph.from_graph_dir(path)

        # Load m2m
        (
            m2m_edge_index,
            m2m_edge_features,
        ) = BaseWeatherGraph._load_subgraph_from_dir(
            path, "m2m"
        )  # (2, M_m2m), (M_m2m, d_edge_features)

        # Load static mesh node features
        mesh_node_features = BaseWeatherGraph._load_feature_tensor(
            path, "mesh_features.pt"
        )  # (N_mesh, d_node_features)

        # Note: We assume that graph features are already normalized
        # when read from disk
        # TODO ^ actually do this in graph creation

        return FlatWeatherGraph(
            base_graph.g2m_edge_index,
            base_graph.g2m_edge_features,
            base_graph.m2g_edge_index,
            base_graph.m2g_edge_features,
            m2m_edge_index,
            m2m_edge_features,
            mesh_node_features,
        )

    def __str__(self):
        """
        Returns a string representation of the graph, including the total
        number of nodes and the breakdown of nodes and edges in subgraphs.
        """
        return super().__str__() + f"\nm2m with {self.num_m2m_edges} edges"
Loading
Loading