-
-
Notifications
You must be signed in to change notification settings - Fork 51
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add graphcast utils * Add graph builder * Rebase * Add gencast loss * Add loss test and exceptions * Fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add source * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
8276ac2
commit a3c24c3
Showing
13 changed files
with
1,831 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,3 +34,5 @@ dependencies: | |
- pysolar | ||
- pytorch-lightning | ||
- click | ||
- trimesh | ||
- rtree |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,3 +35,5 @@ dependencies: | |
- pysolar | ||
- pytorch-lightning | ||
- click | ||
- trimesh | ||
- rtree |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# GenCast |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
"""Main import for GenCast""" | ||
|
||
from .graph.graph_builder import GraphBuilder | ||
from .weighted_mse_loss import WeightedMSELoss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Utils for graph generation.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,304 @@ | ||
"""Build the three graphs for GenCast. | ||
The following code is a port of several components from GraphCast's original graph generation | ||
(https://github.com/google-deepmind/graphcast) to PyG and PyTorch. The graphs are: | ||
- g2m: grid to mesh. | ||
- mesh: icosphere refinement. | ||
- m2g: mesh to grid. | ||
- khop: k-hop neighbours mesh. | ||
""" | ||
|
||
import numpy as np | ||
import torch | ||
from torch_geometric.data import Data, HeteroData | ||
from torch_geometric.transforms import TwoHop | ||
|
||
from graph_weather.models.gencast.graph import grid_mesh_connectivity, icosahedral_mesh, model_utils | ||
|
||
# Some configs from graphcast: | ||
_spatial_features_kwargs = dict( | ||
add_node_positions=False, | ||
add_node_latitude=True, | ||
add_node_longitude=True, | ||
add_relative_positions=True, | ||
relative_longitude_local_coordinates=True, | ||
relative_latitude_local_coordinates=True, | ||
) | ||
|
||
# radius_query_fraction_edge_length: Scalar that will be multiplied by the | ||
# length of the longest edge of the finest mesh to define the radius of | ||
# connectivity to use in the Grid2Mesh graph. Reasonable values are | ||
# between 0.6 and 1. 0.6 reduces the number of grid points feeding into | ||
# multiple mesh nodes and therefore reduces edge count and memory use, but | ||
# 1 gives better predictions. | ||
# mesh2grid_edge_normalization_factor: Allows explicitly controlling edge | ||
# normalization for mesh2grid edges. If None, defaults to max edge length. | ||
# This supports using pre-trained model weights with a different graph | ||
# structure to what it was trained on. | ||
|
||
radius_query_fraction_edge_length = 0.6 | ||
mesh2grid_edge_normalization_factor = None | ||
|
||
|
||
def _get_max_edge_distance(mesh): | ||
senders, receivers = icosahedral_mesh.faces_to_edges(mesh.faces) | ||
edge_distances = np.linalg.norm(mesh.vertices[senders] - mesh.vertices[receivers], axis=-1) | ||
return edge_distances.max() | ||
|
||
|
||
class GraphBuilder: | ||
""" | ||
Class for building GenCast's graphs. | ||
Attributes: | ||
g2m_graph (pyg.data.HeteroData): heterogeneous directed graph connecting the grid nodes | ||
to the mesh nodes. | ||
mesh_graph (pyg.data.Data): undirected graph connecting the mesh nodes. | ||
m2g_graph (pyg.data.HeteroData): heterogeneous directed graph connecting the mesh nodes | ||
to the grid nodes. | ||
khop_mesh_graph (pyg.data.Data): augmented version of mesh_graph in which every node is | ||
connected to its 2^num_hops neighbours. | ||
grid_nodes_dim (int): dimension of the grid nodes features. | ||
mesh_nodes_dim (int): dimension of the mesh nodes features. | ||
mesh_edges_dim (int): dimension of the mesh edges features. | ||
g2m_edges_dim (int): dimension of the "grid to mesh" edges features. | ||
m2g_edges_dim (int): dimension of the "mesh to grid" edges features. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
grid_lon: np.ndarray, | ||
grid_lat: np.ndarray, | ||
splits: int = 5, | ||
num_hops: int = 0, | ||
device: torch.device = torch.device("cpu"), | ||
): | ||
"""Initialize the GraphBuilder object. | ||
Args: | ||
grid_lon: 1D np.ndarray containing the list of longitudes. | ||
grid_lat: 1D np.ndarray containing the list of latitudes. | ||
splits: number of times to split the icosphere to build the mesh. Defaults to 5. | ||
num_hops: if num_hops=k then khop_mesh_graph will be the 2^k-neighbours version of | ||
the mesh. Defaults to 0. | ||
device: the device to which the graph will be moved. | ||
""" | ||
|
||
self._spatial_features_kwargs = _spatial_features_kwargs | ||
self.add_edge_features_to_khop = True | ||
self.device = device | ||
|
||
# Specification of the mesh. | ||
_icosahedral_refinements = icosahedral_mesh.get_hierarchy_of_triangular_meshes_for_sphere( | ||
splits | ||
) | ||
self._mesh = _icosahedral_refinements[-1] | ||
|
||
# Obtain the query radius in absolute units for the unit-sphere for the | ||
# grid2mesh model, by rescaling the `radius_query_fraction_edge_length`. | ||
self._query_radius = _get_max_edge_distance(self._mesh) * radius_query_fraction_edge_length | ||
self._mesh2grid_edge_normalization_factor = mesh2grid_edge_normalization_factor | ||
|
||
self.grid_nodes_dim = None | ||
self.mesh_nodes_dim = None | ||
self.mesh_edges_dim = None | ||
self.g2m_edges_dim = None | ||
self.m2g_edges_dim = None | ||
|
||
# A "_init_mesh_properties": | ||
# This one could be initialized at init but we delay it for consistency too. | ||
self._num_mesh_nodes = None # num_mesh_nodes | ||
self._mesh_nodes_lat = None # [num_mesh_nodes] | ||
self._mesh_nodes_lon = None # [num_mesh_nodes] | ||
|
||
# A "_init_grid_properties": | ||
self._grid_lat = None # [num_lat_points] | ||
self._grid_lon = None # [num_lon_points] | ||
self._num_grid_nodes = None # num_lat_points * num_lon_points | ||
self._grid_nodes_lat = None # [num_grid_nodes] | ||
self._grid_nodes_lon = None # [num_grid_nodes] | ||
|
||
self._init_grid_properties(grid_lat, grid_lon) | ||
self._init_mesh_properties() | ||
self.g2m_graph = self._init_grid2mesh_graph() | ||
self.mesh_graph = self._init_mesh_graph() | ||
self.m2g_graph = self._init_mesh2grid_graph() | ||
|
||
self.num_hops = num_hops | ||
self.khop_mesh_graph = self._init_khop_mesh_graph() | ||
|
||
def _init_grid_properties(self, grid_lat: np.ndarray, grid_lon: np.ndarray): | ||
"""Inits static properties that have to do with grid nodes.""" | ||
self._grid_lat = grid_lat.astype(np.float32) | ||
self._grid_lon = grid_lon.astype(np.float32) | ||
# Initialized the counters. | ||
self._num_grid_nodes = grid_lat.shape[0] * grid_lon.shape[0] | ||
|
||
# Initialize lat and lon for the grid. | ||
grid_nodes_lon, grid_nodes_lat = np.meshgrid(grid_lon, grid_lat) | ||
self._grid_nodes_lon = grid_nodes_lon.reshape([-1]).astype(np.float32) | ||
self._grid_nodes_lat = grid_nodes_lat.reshape([-1]).astype(np.float32) | ||
|
||
def _init_mesh_properties(self): | ||
"""Inits static properties that have to do with mesh nodes.""" | ||
self._num_mesh_nodes = self._mesh.vertices.shape[0] | ||
mesh_phi, mesh_theta = model_utils.cartesian_to_spherical( | ||
self._mesh.vertices[:, 0], self._mesh.vertices[:, 1], self._mesh.vertices[:, 2] | ||
) | ||
( | ||
mesh_nodes_lat, | ||
mesh_nodes_lon, | ||
) = model_utils.spherical_to_lat_lon(phi=mesh_phi, theta=mesh_theta) | ||
# Convert to f32 to ensure the lat/lon features aren't in f64. | ||
self._mesh_nodes_lat = mesh_nodes_lat.astype(np.float32) | ||
self._mesh_nodes_lon = mesh_nodes_lon.astype(np.float32) | ||
|
||
def _init_grid2mesh_graph(self): | ||
"""Build Grid2Mesh graph.""" | ||
|
||
# Create some edges according to distance between mesh and grid nodes. | ||
assert self._grid_lat is not None and self._grid_lon is not None | ||
(grid_indices, mesh_indices) = grid_mesh_connectivity.radius_query_indices( | ||
grid_latitude=self._grid_lat, | ||
grid_longitude=self._grid_lon, | ||
mesh=self._mesh, | ||
radius=self._query_radius, | ||
) | ||
|
||
# Edges sending info from grid to mesh. | ||
senders = grid_indices | ||
receivers = mesh_indices | ||
|
||
# Precompute structural node and edge features according to config options. | ||
# Structural features are those that depend on the fixed values of the | ||
# latitude and longitudes of the nodes. | ||
(senders_node_features, receivers_node_features, edge_features) = ( | ||
model_utils.get_bipartite_graph_spatial_features( | ||
senders_node_lat=self._grid_nodes_lat, | ||
senders_node_lon=self._grid_nodes_lon, | ||
receivers_node_lat=self._mesh_nodes_lat, | ||
receivers_node_lon=self._mesh_nodes_lon, | ||
senders=senders, | ||
receivers=receivers, | ||
edge_normalization_factor=None, | ||
**self._spatial_features_kwargs, | ||
) | ||
) | ||
|
||
self.grid_nodes_dim = senders_node_features.shape[1] | ||
self.mesh_nodes_dim = receivers_node_features.shape[1] | ||
self.g2m_edges_dim = edge_features.shape[1] | ||
|
||
g2m_graph = HeteroData() | ||
g2m_graph["grid_nodes"].x = torch.tensor( | ||
senders_node_features, dtype=torch.float32, device=self.device | ||
) # TODO: generate graph with torch or np? | ||
g2m_graph["mesh_nodes"].x = torch.tensor( | ||
receivers_node_features, dtype=torch.float32, device=self.device | ||
) | ||
g2m_graph["grid_nodes", "to", "mesh_nodes"].edge_index = torch.tensor( | ||
np.stack([senders, receivers]), dtype=torch.long, device=self.device | ||
) | ||
g2m_graph["grid_nodes", "to", "mesh_nodes"].edge_attr = torch.tensor( | ||
edge_features, dtype=torch.float32, device=self.device | ||
) | ||
|
||
return g2m_graph | ||
|
||
def _init_mesh_graph(self): | ||
"""Build Mesh graph.""" | ||
# Work simply on the mesh edges. | ||
senders, receivers = icosahedral_mesh.faces_to_edges(self._mesh.faces) | ||
|
||
# Precompute structural node and edge features according to config options. | ||
# Structural features are those that depend on the fixed values of the | ||
# latitude and longitudes of the nodes. | ||
assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None | ||
node_features, edge_features = model_utils.get_graph_spatial_features( | ||
node_lat=self._mesh_nodes_lat, | ||
node_lon=self._mesh_nodes_lon, | ||
senders=senders, | ||
receivers=receivers, | ||
**self._spatial_features_kwargs, | ||
) | ||
|
||
self.mesh_edges_dim = edge_features.shape[1] | ||
|
||
mesh_graph = Data( | ||
x=torch.tensor(node_features, dtype=torch.float32, device=self.device), | ||
edge_attr=torch.tensor(edge_features, dtype=torch.float32, device=self.device), | ||
edge_index=torch.tensor( | ||
np.stack([senders, receivers]), dtype=torch.long, device=self.device | ||
), | ||
) | ||
|
||
return mesh_graph | ||
|
||
def _init_mesh2grid_graph(self): | ||
"""Build Mesh2Grid graph.""" | ||
|
||
# Create some edges according to how the grid nodes are contained by | ||
# mesh triangles. | ||
(grid_indices, mesh_indices) = grid_mesh_connectivity.in_mesh_triangle_indices( | ||
grid_latitude=self._grid_lat, grid_longitude=self._grid_lon, mesh=self._mesh | ||
) | ||
|
||
# Edges sending info from mesh to grid. | ||
senders = mesh_indices | ||
receivers = grid_indices | ||
|
||
# Precompute structural node and edge features according to config options. | ||
assert self._mesh_nodes_lat is not None and self._mesh_nodes_lon is not None | ||
(senders_node_features, receivers_node_features, edge_features) = ( | ||
model_utils.get_bipartite_graph_spatial_features( | ||
senders_node_lat=self._mesh_nodes_lat, | ||
senders_node_lon=self._mesh_nodes_lon, | ||
receivers_node_lat=self._grid_nodes_lat, | ||
receivers_node_lon=self._grid_nodes_lon, | ||
senders=senders, | ||
receivers=receivers, | ||
edge_normalization_factor=self._mesh2grid_edge_normalization_factor, | ||
**self._spatial_features_kwargs, | ||
) | ||
) | ||
|
||
self.m2g_edges_dim = edge_features.shape[1] | ||
|
||
m2g_graph = HeteroData() | ||
m2g_graph["mesh_nodes"].x = torch.tensor( | ||
senders_node_features, dtype=torch.float32, device=self.device | ||
) | ||
m2g_graph["grid_nodes"].x = torch.tensor( | ||
receivers_node_features, dtype=torch.float32, device=self.device | ||
) | ||
m2g_graph["mesh_nodes", "to", "grid_nodes"].edge_index = torch.tensor( | ||
np.stack([senders, receivers]), dtype=torch.long, device=self.device | ||
) | ||
m2g_graph["mesh_nodes", "to", "grid_nodes"].edge_attr = torch.tensor( | ||
edge_features, dtype=torch.float32, device=self.device | ||
) | ||
|
||
return m2g_graph | ||
|
||
def _init_khop_mesh_graph(self): | ||
"""Build k-hop Mesh graph.""" | ||
|
||
transform = TwoHop() | ||
khop_mesh_graph = self.mesh_graph | ||
for _ in range(self.num_hops): | ||
khop_mesh_graph = transform(khop_mesh_graph) | ||
|
||
if self.add_edge_features_to_khop: | ||
senders = khop_mesh_graph.edge_index[0] | ||
receivers = khop_mesh_graph.edge_index[1] | ||
_, edge_features = model_utils.get_graph_spatial_features( | ||
node_lat=self._mesh_nodes_lat, | ||
node_lon=self._mesh_nodes_lon, | ||
senders=senders, | ||
receivers=receivers, | ||
**self._spatial_features_kwargs, | ||
) | ||
khop_mesh_graph.edge_attr = torch.tensor( | ||
edge_features, dtype=torch.float32, device=self.device | ||
) | ||
return khop_mesh_graph |
Oops, something went wrong.