Skip to content

Commit

Permalink
Start on splitting out graph from Encoder
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobbieker committed Nov 28, 2023
1 parent c7b85ae commit 0f98e3c
Show file tree
Hide file tree
Showing 4 changed files with 232 additions and 1 deletion.
58 changes: 58 additions & 0 deletions graph_weather/models/graphs/hexagonal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Generate hexagonal global grid using Uber's H3 library."""
import h3
import numpy as np
import torch
from torch_geometric.data import Data


def generate_hexagonal_grid(resolution: int = 2) -> np.ndarray:
Expand All @@ -15,3 +17,59 @@ def generate_hexagonal_grid(resolution: int = 2) -> np.ndarray:
base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)}
return np.array(base_h3_grid), base_h3_map

def generate_h3_mapping(lat_lons: list, resolution: int = 2) -> dict:
"""Generate mapping from lat/lon to h3 index.
Args:
lat_lons: List of (lat,lon) points
resolution: H3 resolution level
"""
num_latlons = len(lat_lons)
base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)}
h3_grid = [h3.geo_to_h3(lat, lon, resolution) for lat, lon in lat_lons]
h3_mapping = {}
h_index = len(base_h3_grid)
for h in base_h3_grid:
if h not in h3_mapping:
h_index -= 1
h3_mapping[h] = h_index + num_latlons
# Now have the h3 grid mapping, the bipartite graph of edges connecting lat/lon to h3 nodes
# Should have vertical and horizontal difference
h3_distances = []
for idx, h3_point in enumerate(h3_grid):
lat_lon = lat_lons[idx]
distance = h3.point_dist(lat_lon, h3.h3_to_geo(h3_point), unit="rads")
h3_distances.append([np.sin(distance), np.cos(distance)])
h3_distances = torch.tensor(h3_distances, dtype=torch.float)
return base_h3_map, h3_mapping, h3_distances


def generate_latent_h3_graph(base_h3_map: dict, base_h3_grid: dict) -> torch.Tensor:
"""Generate latent h3 graph.
Args:
base_h3_map: Mapping from h3 index to index in latent graph
h3_mapping: Mapping from lat/lon to h3 index
h3_distances: Distances between lat/lon and h3 index
Returns:
Latent h3 graph
"""
# Get connectivity of the graph
edge_sources = []
edge_targets = []
edge_attrs = []
for h3_index in base_h3_grid:
h_points = h3.k_ring(h3_index, 1)
for h in h_points: # Already includes itself
distance = h3.point_dist(h3.h3_to_geo(h3_index), h3.h3_to_geo(h), unit="rads")
edge_attrs.append([np.sin(distance), np.cos(distance)])
edge_sources.append(base_h3_map[h3_index])
edge_targets.append(base_h3_map[h])
edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)
edge_attrs = torch.tensor(edge_attrs, dtype=torch.float)
# Use heterogeneous graph as input and output dims are not same for the encoder
# Because uniform grid now, don't need edge attributes as they are all the same
return Data(edge_index=edge_index, edge_attr=edge_attrs)
59 changes: 58 additions & 1 deletion graph_weather/models/graphs/ico.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
"""

import numpy as np
import torch
from torch_geometric.data import Data


def icosphere(nu=1, nr_verts=None):
Expand Down Expand Up @@ -294,5 +296,60 @@ def generate_icosphere_graph(resolution=1):
vertices, faces = icosphere(resolution)
edges = np.r_[faces[:, :-1], faces[:, 1:], faces[:, [0, 2]]]
edges = np.unique(np.sort(edges, axis=1), axis=0)
return NotImplementedError("TODO: Make into PyTorch Tensors and return")
return vertices, edges


def generate_icosphere_mapping(lat_lons, resolution=1):
"""
Generate mapping from lat/lon to icosphere index.
Args:
lat_lons: List of (lat,lon) points
resolution: Icosphere resolution level
"""
num_latlons = len(lat_lons)
vertices, faces = icosphere(resolution)
# TODO Actually make this work
h3_mapping = {}
h_index = len(vertices)
for h in vertices:
if h not in h3_mapping:
h_index -= 1
h3_mapping[h] = h_index + num_latlons
# Now have the h3 grid mapping, the bipartite graph of edges connecting lat/lon to h3 nodes
# Should have vertical and horizontal difference
h3_distances = []
for idx, h3_point in enumerate(h3_grid):
lat_lon = lat_lons[idx]
distance = h3.point_dist(lat_lon, h3.h3_to_geo(h3_point), unit="rads")
h3_distances.append([np.sin(distance), np.cos(distance)])
h3_distances = torch.tensor(h3_distances, dtype=torch.float)
return h3_mapping, h3_distances

def generate_latent_ico_graph(h3_mapping, h3_distances):
"""
Generate latent h3 graph.
Args:
base_h3_map: Mapping from h3 index to index in latent graph
h3_mapping: Mapping from lat/lon to h3 index
h3_distances: Distances between lat/lon and h3 index
Returns:
Latent h3 graph
"""
# Get connectivity of the graph
edge_sources = []
edge_targets = []
edge_attrs = []
for h3_index in h3_mapping:
h_points = h3.k_ring(h3_index, 1)
for h in h_points: # Already includes itself
distance = h3.point_dist(h3.h3_to_geo(h3_index), h3.h3_to_geo(h), unit="rads")
edge_attrs.append([np.sin(distance), np.cos(distance)])
edge_sources.append(h3_mapping[h3_index])
edge_targets.append(h3_mapping[h])
edge_sources = np.array(edge_sources)
edge_targets = np.array(edge_targets)
edge_attrs = np.array(edge_attrs)
return edge_sources, edge_targets, edge_attrs
4 changes: 4 additions & 0 deletions graph_weather/models/layers/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def __init__(
hidden_layers_processor_edge=2,
mlp_norm_type="LayerNorm",
use_checkpointing: bool = False,
input_graph: Data = None,
latent_graph: Data = None,
):
"""
Encode the lat/lon data inot the isohedron graph
Expand All @@ -65,6 +67,8 @@ def __init__(
mlp_norm_type: Type of norm for the MLPs
one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None
use_checkpointing: Whether to use gradient checkpointing to use less memory
input_graph: Input graph to use, if None, will generate a default one
latent_graph: Latent graph to use, if None, will generate a default one
"""
super().__init__()
self.use_checkpointing = use_checkpointing
Expand Down
112 changes: 112 additions & 0 deletions graph_weather/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,115 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor):
out = out * self.weights.expand_as(out)
assert not torch.isnan(out).any()
return out.mean()


# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn
from torch.autograd.function import once_differentiable


class CellAreaWeightedLossFunction(nn.Module):
"""Loss function with cell area weighting.
Parameters
----------
area : torch.Tensor
Cell area with shape [H, W].
"""

def __init__(self, area):
super().__init__()
self.area = area

def forward(self, invar, outvar):
"""
Implicit forward function which computes the loss given
a prediction and the corresponding targets.
Parameters
----------
invar : torch.Tensor
prediction of shape [T, C, H, W].
outvar : torch.Tensor
target values of shape [T, C, H, W].
"""

loss = (invar - outvar) ** 2
loss = loss.mean(dim=(0, 1))
loss = torch.mul(loss, self.area)
loss = loss.mean()
return loss


class CustomCellAreaWeightedLossAutogradFunction(torch.autograd.Function):
"""Autograd fuunction for custom loss with cell area weighting."""

@staticmethod
def forward(ctx, invar: torch.Tensor, outvar: torch.Tensor, area: torch.Tensor):
"""Forward of custom loss function with cell area weighting."""

diff = invar - outvar # T x C x H x W
loss = diff**2
loss = loss.mean(dim=(0, 1))
loss = torch.mul(loss, area)
loss = loss.mean()
loss_grad = diff * (2.0 / (math.prod(invar.shape)))
loss_grad *= area.unsqueeze(0).unsqueeze(0)
ctx.save_for_backward(loss_grad)
return loss

@staticmethod
@once_differentiable
def backward(ctx, grad_loss: torch.Tensor):
"""Backward method of custom loss function with cell area weighting."""

# grad_loss should be 1, multiply nevertheless
# to avoid issues with cases where this isn't the case
(grad_invar,) = ctx.saved_tensors
return grad_invar * grad_loss, None, None


class CustomCellAreaWeightedLossFunction(CellAreaWeightedLossFunction):
"""Custom loss function with cell area weighting.
Parameters
----------
area : torch.Tensor
Cell area with shape [H, W].
"""

def __init__(self, area: torch.Tensor):
super().__init__(area)

def forward(self, invar: torch.Tensor, outvar: torch.Tensor) -> torch.Tensor:
"""
Implicit forward function which computes the loss given
a prediction and the corresponding targets.
Parameters
----------
invar : torch.Tensor
prediction of shape [T, C, H, W].
outvar : torch.Tensor
target values of shape [T, C, H, W].
"""

return CustomCellAreaWeightedLossAutogradFunction.apply(
invar, outvar, self.area
)

0 comments on commit 0f98e3c

Please sign in to comment.