diff --git a/graph_weather/models/graphs/hexagonal.py b/graph_weather/models/graphs/hexagonal.py index 00ee82f5..aa9af20a 100644 --- a/graph_weather/models/graphs/hexagonal.py +++ b/graph_weather/models/graphs/hexagonal.py @@ -1,6 +1,8 @@ """Generate hexagonal global grid using Uber's H3 library.""" import h3 import numpy as np +import torch +from torch_geometric.data import Data def generate_hexagonal_grid(resolution: int = 2) -> np.ndarray: @@ -15,3 +17,59 @@ def generate_hexagonal_grid(resolution: int = 2) -> np.ndarray: base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution))) base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)} return np.array(base_h3_grid), base_h3_map + +def generate_h3_mapping(lat_lons: list, resolution: int = 2) -> dict: + """Generate mapping from lat/lon to h3 index. + + Args: + lat_lons: List of (lat,lon) points + resolution: H3 resolution level + """ + num_latlons = len(lat_lons) + base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution))) + base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)} + h3_grid = [h3.geo_to_h3(lat, lon, resolution) for lat, lon in lat_lons] + h3_mapping = {} + h_index = len(base_h3_grid) + for h in base_h3_grid: + if h not in h3_mapping: + h_index -= 1 + h3_mapping[h] = h_index + num_latlons + # Now have the h3 grid mapping, the bipartite graph of edges connecting lat/lon to h3 nodes + # Should have vertical and horizontal difference + h3_distances = [] + for idx, h3_point in enumerate(h3_grid): + lat_lon = lat_lons[idx] + distance = h3.point_dist(lat_lon, h3.h3_to_geo(h3_point), unit="rads") + h3_distances.append([np.sin(distance), np.cos(distance)]) + h3_distances = torch.tensor(h3_distances, dtype=torch.float) + return base_h3_map, h3_mapping, h3_distances + + +def generate_latent_h3_graph(base_h3_map: dict, base_h3_grid: dict) -> torch.Tensor: + """Generate latent h3 graph. + + Args: + base_h3_map: Mapping from h3 index to index in latent graph + h3_mapping: Mapping from lat/lon to h3 index + h3_distances: Distances between lat/lon and h3 index + + Returns: + Latent h3 graph + """ + # Get connectivity of the graph + edge_sources = [] + edge_targets = [] + edge_attrs = [] + for h3_index in base_h3_grid: + h_points = h3.k_ring(h3_index, 1) + for h in h_points: # Already includes itself + distance = h3.point_dist(h3.h3_to_geo(h3_index), h3.h3_to_geo(h), unit="rads") + edge_attrs.append([np.sin(distance), np.cos(distance)]) + edge_sources.append(base_h3_map[h3_index]) + edge_targets.append(base_h3_map[h]) + edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long) + edge_attrs = torch.tensor(edge_attrs, dtype=torch.float) + # Use heterogeneous graph as input and output dims are not same for the encoder + # Because uniform grid now, don't need edge attributes as they are all the same + return Data(edge_index=edge_index, edge_attr=edge_attrs) diff --git a/graph_weather/models/graphs/ico.py b/graph_weather/models/graphs/ico.py index b9302f69..90009fc7 100644 --- a/graph_weather/models/graphs/ico.py +++ b/graph_weather/models/graphs/ico.py @@ -32,6 +32,8 @@ """ import numpy as np +import torch +from torch_geometric.data import Data def icosphere(nu=1, nr_verts=None): @@ -294,5 +296,60 @@ def generate_icosphere_graph(resolution=1): vertices, faces = icosphere(resolution) edges = np.r_[faces[:, :-1], faces[:, 1:], faces[:, [0, 2]]] edges = np.unique(np.sort(edges, axis=1), axis=0) - return NotImplementedError("TODO: Make into PyTorch Tensors and return") return vertices, edges + + +def generate_icosphere_mapping(lat_lons, resolution=1): + """ + Generate mapping from lat/lon to icosphere index. + + Args: + lat_lons: List of (lat,lon) points + resolution: Icosphere resolution level + """ + num_latlons = len(lat_lons) + vertices, faces = icosphere(resolution) + # TODO Actually make this work + h3_mapping = {} + h_index = len(vertices) + for h in vertices: + if h not in h3_mapping: + h_index -= 1 + h3_mapping[h] = h_index + num_latlons + # Now have the h3 grid mapping, the bipartite graph of edges connecting lat/lon to h3 nodes + # Should have vertical and horizontal difference + h3_distances = [] + for idx, h3_point in enumerate(h3_grid): + lat_lon = lat_lons[idx] + distance = h3.point_dist(lat_lon, h3.h3_to_geo(h3_point), unit="rads") + h3_distances.append([np.sin(distance), np.cos(distance)]) + h3_distances = torch.tensor(h3_distances, dtype=torch.float) + return h3_mapping, h3_distances + +def generate_latent_ico_graph(h3_mapping, h3_distances): + """ + Generate latent h3 graph. + + Args: + base_h3_map: Mapping from h3 index to index in latent graph + h3_mapping: Mapping from lat/lon to h3 index + h3_distances: Distances between lat/lon and h3 index + + Returns: + Latent h3 graph + """ + # Get connectivity of the graph + edge_sources = [] + edge_targets = [] + edge_attrs = [] + for h3_index in h3_mapping: + h_points = h3.k_ring(h3_index, 1) + for h in h_points: # Already includes itself + distance = h3.point_dist(h3.h3_to_geo(h3_index), h3.h3_to_geo(h), unit="rads") + edge_attrs.append([np.sin(distance), np.cos(distance)]) + edge_sources.append(h3_mapping[h3_index]) + edge_targets.append(h3_mapping[h]) + edge_sources = np.array(edge_sources) + edge_targets = np.array(edge_targets) + edge_attrs = np.array(edge_attrs) + return edge_sources, edge_targets, edge_attrs \ No newline at end of file diff --git a/graph_weather/models/layers/encoder.py b/graph_weather/models/layers/encoder.py index 916a6523..f4c82f80 100644 --- a/graph_weather/models/layers/encoder.py +++ b/graph_weather/models/layers/encoder.py @@ -48,6 +48,8 @@ def __init__( hidden_layers_processor_edge=2, mlp_norm_type="LayerNorm", use_checkpointing: bool = False, + input_graph: Data = None, + latent_graph: Data = None, ): """ Encode the lat/lon data inot the isohedron graph @@ -65,6 +67,8 @@ def __init__( mlp_norm_type: Type of norm for the MLPs one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None use_checkpointing: Whether to use gradient checkpointing to use less memory + input_graph: Input graph to use, if None, will generate a default one + latent_graph: Latent graph to use, if None, will generate a default one """ super().__init__() self.use_checkpointing = use_checkpointing diff --git a/graph_weather/models/losses.py b/graph_weather/models/losses.py index 5e5d27ec..c417e3b6 100644 --- a/graph_weather/models/losses.py +++ b/graph_weather/models/losses.py @@ -61,3 +61,115 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor): out = out * self.weights.expand_as(out) assert not torch.isnan(out).any() return out.mean() + + +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn +from torch.autograd.function import once_differentiable + + +class CellAreaWeightedLossFunction(nn.Module): + """Loss function with cell area weighting. + + Parameters + ---------- + area : torch.Tensor + Cell area with shape [H, W]. + """ + + def __init__(self, area): + super().__init__() + self.area = area + + def forward(self, invar, outvar): + """ + Implicit forward function which computes the loss given + a prediction and the corresponding targets. + + Parameters + ---------- + invar : torch.Tensor + prediction of shape [T, C, H, W]. + outvar : torch.Tensor + target values of shape [T, C, H, W]. + """ + + loss = (invar - outvar) ** 2 + loss = loss.mean(dim=(0, 1)) + loss = torch.mul(loss, self.area) + loss = loss.mean() + return loss + + +class CustomCellAreaWeightedLossAutogradFunction(torch.autograd.Function): + """Autograd fuunction for custom loss with cell area weighting.""" + + @staticmethod + def forward(ctx, invar: torch.Tensor, outvar: torch.Tensor, area: torch.Tensor): + """Forward of custom loss function with cell area weighting.""" + + diff = invar - outvar # T x C x H x W + loss = diff**2 + loss = loss.mean(dim=(0, 1)) + loss = torch.mul(loss, area) + loss = loss.mean() + loss_grad = diff * (2.0 / (math.prod(invar.shape))) + loss_grad *= area.unsqueeze(0).unsqueeze(0) + ctx.save_for_backward(loss_grad) + return loss + + @staticmethod + @once_differentiable + def backward(ctx, grad_loss: torch.Tensor): + """Backward method of custom loss function with cell area weighting.""" + + # grad_loss should be 1, multiply nevertheless + # to avoid issues with cases where this isn't the case + (grad_invar,) = ctx.saved_tensors + return grad_invar * grad_loss, None, None + + +class CustomCellAreaWeightedLossFunction(CellAreaWeightedLossFunction): + """Custom loss function with cell area weighting. + + Parameters + ---------- + area : torch.Tensor + Cell area with shape [H, W]. + """ + + def __init__(self, area: torch.Tensor): + super().__init__(area) + + def forward(self, invar: torch.Tensor, outvar: torch.Tensor) -> torch.Tensor: + """ + Implicit forward function which computes the loss given + a prediction and the corresponding targets. + + Parameters + ---------- + invar : torch.Tensor + prediction of shape [T, C, H, W]. + outvar : torch.Tensor + target values of shape [T, C, H, W]. + """ + + return CustomCellAreaWeightedLossAutogradFunction.apply( + invar, outvar, self.area + ) \ No newline at end of file