diff --git a/.github/workflows/workflows.yaml b/.github/workflows/workflows.yaml
index e8bf8322..3388db80 100644
--- a/.github/workflows/workflows.yaml
+++ b/.github/workflows/workflows.yaml
@@ -9,13 +9,13 @@ jobs:
       fail-fast: true
       matrix:
         os: [ubuntu-latest]
-        python-version: [3.8, 3.9, "3.10"]
-        torch-version: [1.13.0, 2.0.0]
+        python-version: ["3.10", "3.11"]
+        torch-version: [2.0.0, 2.1.0]
         include:
-          - torch-version: 1.13.0
-            torchvision-version: 0.14.0
           - torch-version: 2.0.0
             torchvision-version: 0.15.1
+          - torch-version: 2.1.0
+            torchvision-version: 0.16.0
 
     steps:
       - uses: actions/checkout@v2
diff --git a/graph_weather/data/utils.py b/graph_weather/data/utils.py
new file mode 100644
index 00000000..e69de29b
diff --git a/graph_weather/models/graphs/__init__.py b/graph_weather/models/graphs/__init__.py
new file mode 100644
index 00000000..9117eb10
--- /dev/null
+++ b/graph_weather/models/graphs/__init__.py
@@ -0,0 +1 @@
+"""Set of graph classes for generating different meshes"""
diff --git a/graph_weather/models/graphs/hexagonal.py b/graph_weather/models/graphs/hexagonal.py
new file mode 100644
index 00000000..fb0dfc1e
--- /dev/null
+++ b/graph_weather/models/graphs/hexagonal.py
@@ -0,0 +1,76 @@
+"""Generate a hexagonal global grid using Uber's H3 library."""
+import h3
+import numpy as np
+import torch
+from torch_geometric.data import Data
+
+
+def generate_hexagonal_grid(resolution: int = 2) -> tuple:
+    """Generate a hexagonal global grid using Uber's H3 library.
+
+    Args:
+        resolution: H3 resolution level
+
+    Returns:
+        Tuple of the sorted H3 indices and a mapping from H3 index to position in the grid
+    """
+    base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
+    base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)}
+    return np.array(base_h3_grid), base_h3_map
+
+
+def generate_h3_mapping(lat_lons: list, resolution: int = 2) -> tuple:
+    """Generate the mapping from lat/lon points to H3 indices.
+
+    Args:
+        lat_lons: List of (lat,lon) points
+        resolution: H3 resolution level
+
+    Returns:
+        Tuple of the base H3 map, the lat/lon to H3 mapping, and the lat/lon to H3 distances
+    """
+    num_latlons = len(lat_lons)
+    base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
+    base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)}
+    h3_grid = [h3.geo_to_h3(lat, lon, resolution) for lat, lon in lat_lons]
+    h3_mapping = {}
+    h_index = len(base_h3_grid)
+    for h in base_h3_grid:
+        if h not in h3_mapping:
+            h_index -= 1
+            h3_mapping[h] = h_index + num_latlons
+    # Now have the h3 grid mapping, build the bipartite graph of edges connecting lat/lon to h3 nodes
+    # Should have vertical and horizontal difference
+    h3_distances = []
+    for idx, h3_point in enumerate(h3_grid):
+        lat_lon = lat_lons[idx]
+        distance = h3.point_dist(lat_lon, h3.h3_to_geo(h3_point), unit="rads")
+        h3_distances.append([np.sin(distance), np.cos(distance)])
+    h3_distances = torch.tensor(h3_distances, dtype=torch.float)
+    return base_h3_map, h3_mapping, h3_distances
+
+
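+# Editor's usage sketch (not part of the original PR): how the two helpers above
+# compose. The cell count assumes the H3 v3 API used in this module
+# (resolution 1 has 842 cells).
+def _example_h3_usage() -> None:
+    """Illustrative only: build the global grid and a lat/lon -> H3 mapping."""
+    grid, grid_map = generate_hexagonal_grid(resolution=1)
+    base_map, mapping, distances = generate_h3_mapping([(0.0, 0.0), (45.0, 90.0)], resolution=1)
+    print(grid.shape, len(grid_map))  # (842,) hexagon indices and their positions
+    print(distances.shape)  # (2, 2): [sin(d), cos(d)] per lat/lon point
+
+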
+def generate_latent_h3_graph(base_h3_map: dict, base_h3_grid: np.ndarray) -> Data:
+    """Generate the latent h3 graph.
+
+    Args:
+        base_h3_map: Mapping from h3 index to index in the latent graph
+        base_h3_grid: Array of h3 indices making up the latent graph
+
+    Returns:
+        Latent h3 graph as a PyTorch Geometric ``Data`` object
+    """
+    # Get connectivity of the graph
+    edge_sources = []
+    edge_targets = []
+    edge_attrs = []
+    for h3_index in base_h3_grid:
+        h_points = h3.k_ring(h3_index, 1)
+        for h in h_points:  # Already includes itself
+            distance = h3.point_dist(h3.h3_to_geo(h3_index), h3.h3_to_geo(h), unit="rads")
+            edge_attrs.append([np.sin(distance), np.cos(distance)])
+            edge_sources.append(base_h3_map[h3_index])
+            edge_targets.append(base_h3_map[h])
+    edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)
+    edge_attrs = torch.tensor(edge_attrs, dtype=torch.float)
+    # Use a heterogeneous graph, as the input and output dims are not the same for the encoder
+    # Because the grid is uniform, the edge attributes are all nearly identical and not strictly needed
+    return Data(edge_index=edge_index, edge_attr=edge_attrs)
diff --git a/graph_weather/models/graphs/ico.py b/graph_weather/models/graphs/ico.py
new file mode 100644
index 00000000..c9e96a59
--- /dev/null
+++ b/graph_weather/models/graphs/ico.py
@@ -0,0 +1,392 @@
+"""
+Creating a geodesic icosahedron with a given (integer) subdivision frequency (and
+not by recursively applying Loop-like subdivision).
+
+The advantage of subdivision frequency over recursive subdivision is in
+controlling the mesh resolution. Mesh resolution grows quadratically with
+subdivision frequency, while it grows exponentially with iterations of
+recursive subdivision. To be precise, using recursive
+subdivision (each iteration being a subdivision with frequency nu=2), the
+possible number of vertices grows with iterations i as
+    [12+10*(2**i+1)*(2**i-1) for i in range(10)]
+which gives
+    [12, 42, 162, 642, 2562, 10242, 40962, 163842, 655362, 2621442].
+Notice for example there is no mesh having between 2562 and 10242 vertices.
+Using subdivision frequency, the possible number of vertices grows with nu as
+    [12+10*(nu+1)*(nu-1) for nu in range(1,33)]
+which gives
+    [12, 42, 92, 162, 252, 362, 492, 642, 812, 1002, 1212, 1442, 1692, 1962,
+     2252, 2562, 2892, 3242, 3612, 4002, 4412, 4842, 5292, 5762, 6252, 6762,
+     7292, 7842, 8412, 9002, 9612, 10242]
+where nu = 32 gives 10242 vertices, and there are 15 meshes having between
+2562 and 10242 vertices. The advantage is even more pronounced at
+higher resolutions.
+
+Author: vand@dtu.dk, 2014, 2017, 2021.
+Originally developed in connection with
+https://ieeexplore.ieee.org/document/7182720
+
+This code is copied in because there is an improvement to the inside_points function,
+not yet merged upstream, that speeds up generation 5-8x. See https://github.com/vedranaa/icosphere/pull/3
+"""
+
+from typing import Tuple
+
+import numpy as np
+import torch
+from torch_geometric.data import Data, HeteroData
+
+from graph_weather.models.graphs.utils import (
+    generate_grid_to_mesh,
+    generate_mesh_to_grid,
+    get_edge_len,
+)
+
+
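+# Editor's worked example (added for clarity, not in the original PR): icosphere()
+# below inverts nr_verts = 12 + 10*(nu+1)*(nu-1) to pick nu. For nr_verts = 5000:
+#     nu_min = ceil(sqrt(1 + (5000 - 12) / 10)) = ceil(22.36...) = 23,
+# giving 12 + 10*24*22 = 5292 vertices, the smallest mesh with at least 5000.
+
+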
+def icosphere(nu=1, nr_verts=None):
+    """
+    Returns a geodesic icosahedron with subdivision frequency nu. Frequency
+    nu = 1 returns a regular unit icosahedron, and nu > 1 performs subdivision.
+    If nr_verts is given, nu will be adjusted such that the icosphere contains
+    at least nr_verts vertices. Returned faces are zero-indexed!
+
+    Parameters
+    ----------
+    nu : subdivision frequency, integer (larger than 1 to make a change).
+    nr_verts : desired number of mesh vertices; if given, nu may be increased.
+
+    Returns
+    -------
+    subvertices : vertex list, numpy array of shape (12+10*(nu+1)*(nu-1), 3)
+    subfaces : face list, numpy array of shape (20*nu**2, 3)
+    """
+
+    # Unit icosahedron
+    (vertices, faces) = icosahedron()
+
+    # If nr_verts is given, compute the appropriate subdivision frequency nu.
+    # We know nr_verts = 12+10*(nu+1)*(nu-1)
+    if nr_verts is not None:
+        nu_min = np.ceil(np.sqrt(max(1 + (nr_verts - 12) / 10, 1)))
+        nu = max(nu, nu_min)
+
+    # Subdividing
+    if nu > 1:
+        (vertices, faces) = subdivide_mesh(vertices, faces, nu)
+        vertices = vertices / np.sqrt(np.sum(vertices**2, axis=1, keepdims=True))
+
+    return (vertices, faces)
+
+
+def icosahedron():
+    """Regular unit icosahedron."""
+
+    # 12 principal directions in 3D space: points on a unit icosahedron
+    phi = (1 + np.sqrt(5)) / 2
+    vertices = np.array(
+        [[0, 1, phi], [0, -1, phi], [1, phi, 0], [-1, phi, 0], [phi, 0, 1], [-phi, 0, 1]]
+    ) / np.sqrt(1 + phi**2)
+    vertices = np.r_[vertices, -vertices]
+
+    # 20 faces
+    faces = np.array(
+        [
+            [0, 5, 1],
+            [0, 3, 5],
+            [0, 2, 3],
+            [0, 4, 2],
+            [0, 1, 4],
+            [1, 5, 8],
+            [5, 3, 10],
+            [3, 2, 7],
+            [2, 4, 11],
+            [4, 1, 9],
+            [7, 11, 6],
+            [11, 9, 6],
+            [9, 8, 6],
+            [8, 10, 6],
+            [10, 7, 6],
+            [2, 11, 7],
+            [4, 9, 11],
+            [1, 8, 9],
+            [5, 10, 8],
+            [3, 7, 10],
+        ],
+        dtype=int,
+    )
+
+    return (vertices, faces)
+
+
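+# Editor's sanity check (illustrative, not in the original PR): the base mesh has
+# 12 unit-norm vertices and 20 faces, hence 30 unique edges (Euler: V - E + F = 2).
+def _check_icosahedron() -> None:
+    vertices, faces = icosahedron()
+    assert vertices.shape == (12, 3) and faces.shape == (20, 3)
+    assert np.allclose(np.linalg.norm(vertices, axis=1), 1.0)
+    edges = np.unique(np.sort(np.r_[faces[:, :2], faces[:, 1:], faces[:, ::2]], axis=1), axis=0)
+    assert len(edges) == 30
+
+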
+def subdivide_mesh(vertices, faces, nu):
+    """
+    Subdivides mesh by adding vertices on mesh edges and faces. Each edge
+    will be divided into nu segments. (For example, for nu=2 one vertex is added
+    on each mesh edge, for nu=3 two vertices are added on each mesh edge and
+    one vertex is added on each face.) If V and F are the number of mesh vertices
+    and mesh faces of the input mesh, the subdivided mesh contains
+    V + F*(nu+1)*(nu-1)/2 vertices and F*nu**2 faces.
+
+    Parameters
+    ----------
+    vertices : vertex list, numpy array of shape (V,3)
+    faces : face list, numpy array of shape (F,3). Zero indexed.
+    nu : subdivision frequency, integer (larger than 1 to make a change).
+
+    Returns
+    -------
+    subvertices : vertex list, numpy array of shape (V + F*(nu+1)*(nu-1)/2, 3)
+    subfaces : face list, numpy array of shape (F*nu**2, 3)
+
+    Author: vand at dtu.dk, 8.12.2017. Translated to python 6.4.2021
+    """
+
+    edges = np.r_[faces[:, :-1], faces[:, 1:], faces[:, [0, 2]]]
+    edges = np.unique(np.sort(edges, axis=1), axis=0)
+    F = faces.shape[0]
+    V = vertices.shape[0]
+    E = edges.shape[0]
+    subfaces = np.empty((F * nu**2, 3), dtype=int)
+    subvertices = np.empty((V + E * (nu - 1) + F * (nu - 1) * (nu - 2) // 2, 3))
+
+    subvertices[:V] = vertices
+
+    # Dictionary for accessing edge index from indices of edge vertices.
+    edge_indices = dict()
+    for i in range(V):
+        edge_indices[i] = dict()
+    for i in range(E):
+        edge_indices[edges[i, 0]][edges[i, 1]] = i
+        edge_indices[edges[i, 1]][edges[i, 0]] = -i
+
+    template = faces_template(nu)
+    ordering = vertex_ordering(nu)
+    reordered_template = ordering[template]
+
+    # At this point, we have V vertices, and now we add (nu-1) vertices per edge
+    # (on-edge vertices).
+    w = np.arange(1, nu) / nu  # interpolation weights
+    for e in range(E):
+        edge = edges[e]
+        for k in range(nu - 1):
+            subvertices[V + e * (nu - 1) + k] = (
+                w[-1 - k] * vertices[edge[0]] + w[k] * vertices[edge[1]]
+            )
+
+    # At this point we have E*(nu-1)+V vertices, and we add (nu-1)*(nu-2)/2
+    # vertices per face (on-face vertices).
+    r = np.arange(nu - 1)
+    for f in range(F):
+        # First, fixing connectivity. We get hold of the indices of all
+        # vertices involved in this subface: original, on-edge and on-face.
+        T = np.arange(
+            f * (nu - 1) * (nu - 2) // 2 + E * (nu - 1) + V,
+            (f + 1) * (nu - 1) * (nu - 2) // 2 + E * (nu - 1) + V,
+        )  # will be added
+        eAB = edge_indices[faces[f, 0]][faces[f, 1]]
+        eAC = edge_indices[faces[f, 0]][faces[f, 2]]
+        eBC = edge_indices[faces[f, 1]][faces[f, 2]]
+        AB = reverse(abs(eAB) * (nu - 1) + V + r, eAB < 0)  # already added
+        AC = reverse(abs(eAC) * (nu - 1) + V + r, eAC < 0)  # already added
+        BC = reverse(abs(eBC) * (nu - 1) + V + r, eBC < 0)  # already added
+        VEF = np.r_[faces[f], AB, AC, BC, T]
+        subfaces[f * nu**2 : (f + 1) * nu**2, :] = VEF[reordered_template]
+        # Now geometry, computing positions of face vertices.
+        subvertices[T, :] = inside_points(subvertices[AB, :], subvertices[AC, :])
+
+    return (subvertices, subfaces)
+
+
+def reverse(vector, flag):
+    """For reversing the direction of an edge."""
+
+    if flag:
+        vector = vector[::-1]
+    return vector
+
+
+def faces_template(nu):
+    """
+    Template for linking subfaces                  0
+    in a subdivision of a face.                   / \
+    Returns faces with vertex                    1---2
+    indexing given by reading order             / \ / \
+    (as illustrated).                          3---4---5
+                                              / \ / \ / \
+                                             6---7---8---9
+                                            / \ / \ / \ / \
+                                           10--11--12--13--14
+    """
+
+    faces = []
+    # looping in layers of triangles
+    for i in range(nu):
+        vertex0 = i * (i + 1) // 2
+        skip = i + 1
+        for j in range(i):  # adding pairs of triangles, will not run for i==0
+            faces.append([j + vertex0, j + vertex0 + skip, j + vertex0 + skip + 1])
+            faces.append([j + vertex0, j + vertex0 + skip + 1, j + vertex0 + 1])
+        # adding the last (unpaired, rightmost) triangle
+        faces.append([i + vertex0, i + vertex0 + skip, i + vertex0 + skip + 1])
+
+    return np.array(faces)
+
+
+def vertex_ordering(nu):
+    """
+    Permutation for ordering of                    0
+    face vertices which transforms                / \
+    reading-order indexing into indexing         3---6
+    first corner vertices, then on-edge         / \ / \
+    vertices, and then on-face vertices        4---12--7
+    (as illustrated).                         / \ / \ / \
+                                             5---13--14--8
+                                            / \ / \ / \ / \
+                                           1---9--10--11---2
+    """
+
+    left = [j for j in range(3, nu + 2)]
+    right = [j for j in range(nu + 2, 2 * nu + 1)]
+    bottom = [j for j in range(2 * nu + 1, 3 * nu)]
+    inside = [j for j in range(3 * nu, (nu + 1) * (nu + 2) // 2)]
+
+    o = [0]  # topmost corner
+    for i in range(nu - 1):
+        o.append(left[i])
+        o = o + inside[i * (i - 1) // 2 : i * (i + 1) // 2]
+        o.append(right[i])
+    o = o + [1] + bottom + [2]
+
+    return np.array(o)
+
+
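+# Editor's worked example (added for clarity): for nu=2 there are no on-face
+# vertices, and vertex_ordering(2) == [0, 3, 4, 1, 5, 2], i.e. reading order
+# [A, AB-mid, AC-mid, B, BC-mid, C] as indices into VEF = [A, B, C, AB, AC, BC].
+
+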
+def inside_points(vAB, vAC):
+    """
+    Returns coordinates of the inside              .
+    (on-face) vertices (marked by star)           / \
+    for subdivision of the face ABC when       vAB0---vAC0
+    given coordinates of the on-edge            / \ / \
+    vertices AB[i] and AC[i].                vAB1---*---vAC1
+                                              / \ / \ / \
+                                           vAB2---*---*---vAC2
+                                            / \ / \ / \ / \
+                                             .---.---.---.---.
+    """
+    out = []
+    u = vAB.shape[0]
+    for i in range(0 if u == 1 else 1, u):
+        # Linearly interpolate between vABi and vACi in `i + 1` (`j`) steps,
+        # not including the endpoints.
+        # This could be written as
+        #     vABi = vAB[i, :]
+        #     vACi = vAC[i, :]
+        #     interp_multipliers = np.arange(1, j) / j
+        #     res = np.outer(interp_multipliers, vACi) + np.outer(1 - interp_multipliers, vABi)
+        # but that would involve some extra work on `np.outer`'s part that we can
+        # do ourselves since we know the shapes we're working with.
+        j = i + 1
+        interp_multipliers = (np.arange(1, j) / j)[:, None]
+        out.append(
+            np.multiply(interp_multipliers, vAC[i, None])
+            + np.multiply(1 - interp_multipliers, vAB[i, None])
+        )
+    return np.concatenate(out)
+
+
+def generate_icosphere_graph(resolution=1):
+    """Generate the vertices and unique undirected edges of an icosphere with subdivision frequency `resolution`."""
+    vertices, faces = icosphere(resolution)
+    edges = np.r_[faces[:, :-1], faces[:, 1:], faces[:, [0, 2]]]
+    edges = np.unique(np.sort(edges, axis=1), axis=0)
+    return vertices, edges
+
+
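+# Editor's note (added for clarity): an icosphere with frequency nu has
+# 12 + 10*(nu+1)*(nu-1) vertices and 30*nu**2 unique undirected edges, e.g.
+#     generate_icosphere_graph(1) -> 12 vertices, 30 edges
+#     generate_icosphere_graph(2) -> 42 vertices, 120 edges
+
+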
+def generate_icosphere_mapping(
+    lat_lons, resolutions=(1, 2, 4, 8, 16), bidirectional=True
+) -> Tuple[HeteroData, Data, HeteroData]:
+    """
+    Generate the mapping from lat/lon points to icosphere nodes.
+
+    GraphCast maps from lat/lon directly to all the different icosphere levels,
+    and maps back from all the different icosphere levels to lat/lon.
+    Nodes are shared to an extent: the 12 base nodes are connected to all the other layers and have more
+    incoming edges than the further refined ones. In the most extreme case, the finest-resolution
+    icosphere nodes have only the incoming edges from the nearby nodes. For 6 different layers, each layer adds 5
+    more incoming edges to the nodes in the layers above it, so the top-layer nodes have 30 incoming edges while the
+    finest ones have 5. The grid-to-mesh graph is bipartite, with edges between the lat/lon nodes and the icosphere nodes.
+
+    Args:
+        lat_lons: List of (lat,lon) points
+        resolutions: Icosphere resolution levels, in ascending order of resolution;
+            the first 7 levels correspond to the GraphCast levels
+        bidirectional: Whether to add the reversed copy of each mesh edge
+    """
+    num_latlons = len(lat_lons)
+    vertices_per_level = []
+    edges_per_level = []
+    for resolution in resolutions:
+        vertices, edges = generate_icosphere_graph(resolution)
+        vertices_per_level.append(vertices)
+        edges_per_level.append(edges)
+
+    # Check the vertices of each coarser level against the finest level, and remap
+    # the coarser level's edges onto the finest level's vertex indexing
+    for i in range(len(vertices_per_level) - 1):
+        for vertex_lower_index, vertex in enumerate(vertices_per_level[i]):
+            # Go through each vertex in the current level and find the matching
+            # vertex in the finest level
+            multiple_equals = 0
+            for vertex_upper_index, upper_vertex in enumerate(vertices_per_level[-1]):
+                if np.all(vertex == upper_vertex):
+                    multiple_equals += 1
+                    # Manually update the edge tuples to use vertex_upper_index
+                    # instead of vertex_lower_index
+                    for edge_index, edge in enumerate(edges_per_level[i]):
+                        if edge[0] == vertex_lower_index:
+                            edges_per_level[i][edge_index][0] = vertex_upper_index
+                        if edge[1] == vertex_lower_index:
+                            edges_per_level[i][edge_index][1] = vertex_upper_index
+            if multiple_equals > 1:
+                print(f"Multiple equals: {multiple_equals}")
+    vertices = vertices_per_level[-1]  # The finest level contains all the vertices of the coarser ones
+    edges = np.sort(np.concatenate(edges_per_level), axis=1)
+    print(f"Number of edges: {len(edges)}")
+    if bidirectional:
+        # Add the flipped version of each edge so the graph is bidirectional
+        edges = np.concatenate([edges, edges[:, [1, 0]]], axis=0)
+    # TODO Create mapping from the lat/lon to the icosphere nodes
+    print(f"Now Number of edges: {len(edges)}")
+    print(f"Number of unique edges: {len(np.unique(edges, axis=0))}")
+    print(
+        f"Max number of repeated edges: {np.max(np.unique(edges, axis=0, return_counts=True)[1])}"
+    )
+    u, c = np.unique(edges, axis=0, return_counts=True)
+    print(f"First 20 duplicates: {u[c > 1][:20]}")
+    # Features will need to be in the same lat/lon order as given, and added to the vertices
+    ico_graph = Data(
+        pos=torch.tensor(vertices, dtype=torch.float),
+        edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(),
+    )
+    max_edge_len = np.max(
+        get_edge_len(
+            ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]]
+        )
+    )
+    # Check that the graph is valid
+    ico_graph.validate(raise_on_error=True)
+
+    # Generate the grid-to-mesh and mesh-to-grid bipartite graphs
+    grid_to_mesh = generate_grid_to_mesh(lat_lons, ico_graph, max_edge_length=max_edge_len)
+    mesh_to_grid = generate_mesh_to_grid(lat_lons, ico_graph)
+
+    return grid_to_mesh, ico_graph, mesh_to_grid
+
+
+if __name__ == "__main__":
+    generate_icosphere_mapping([(0, 0), (0, 1), (1, 0), (1, 1)])
diff --git a/graph_weather/models/graphs/utils.py b/graph_weather/models/graphs/utils.py
new file mode 100644
index 00000000..3d651b06
--- /dev/null
+++ b/graph_weather/models/graphs/utils.py
@@ -0,0 +1,439 @@
+"""
+Most of this is copied from https://github.com/NVIDIA/modulus/blob/main/modulus/utils/graphcast/graph_utils.py
+
+There are some adaptations for PyTorch Geometric data structures instead.
+
+# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from sklearn.neighbors import NearestNeighbors
+from torch import Tensor, testing
+from torch_geometric.data import Data, HeteroData
+
+
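+# Editor's note (added for clarity): with the conventions used below,
+# (lat, lon) = (0, 0) maps to (1, 0, 0), (0, 90) maps to (0, 1, 0), and the
+# north pole (90, lon) maps to (0, 0, 1).
+
+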
+def latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = "deg") -> Tensor:
+    """
+    Converts lat/lon points to xyz coordinates.
+    Based on: https://stackoverflow.com/questions/1185408
+    - The x-axis goes through long,lat (0,0);
+    - The y-axis goes through (0,90);
+    - The z-axis goes through the poles.
+
+    Parameters
+    ----------
+    latlon : Tensor
+        Tensor of shape (N, 2) containing latitudes and longitudes
+    radius : float, optional
+        Radius of the sphere, by default 1
+    unit : str, optional
+        Unit of the latlon, by default "deg"
+
+    Returns
+    -------
+    Tensor
+        Tensor of shape (N, 3) containing x, y, z coordinates
+    """
+    if unit == "deg":
+        latlon = deg2rad(latlon)
+    elif unit == "rad":
+        pass
+    else:
+        raise ValueError("Not a valid unit")
+    lat, lon = latlon[:, 0], latlon[:, 1]
+    x = radius * torch.cos(lat) * torch.cos(lon)
+    y = radius * torch.cos(lat) * torch.sin(lon)
+    z = radius * torch.sin(lat)
+    return torch.stack((x, y, z), dim=1)
+
+
+def xyz2latlon(xyz: Tensor, radius: float = 1, unit: str = "deg") -> Tensor:
+    """
+    Converts xyz coordinates to lat/lon.
+    Based on: https://stackoverflow.com/questions/1185408
+    - The x-axis goes through long,lat (0,0);
+    - The y-axis goes through (0,90);
+    - The z-axis goes through the poles.
+
+    Parameters
+    ----------
+    xyz : Tensor
+        Tensor of shape (N, 3) containing x, y, z coordinates
+    radius : float, optional
+        Radius of the sphere, by default 1
+    unit : str, optional
+        Unit of the returned latlon, by default "deg"
+
+    Returns
+    -------
+    Tensor
+        Tensor of shape (N, 2) containing latitudes and longitudes
+    """
+    lat = torch.arcsin(xyz[:, 2] / radius)
+    lon = torch.arctan2(xyz[:, 1], xyz[:, 0])
+    if unit == "deg":
+        return torch.stack((rad2deg(lat), rad2deg(lon)), dim=1)
+    elif unit == "rad":
+        return torch.stack((lat, lon), dim=1)
+    else:
+        raise ValueError("Not a valid unit")
+
+
+def get_edge_len(edge_src: Tensor, edge_dst: Tensor, axis: int = 1):
+    """Returns the length of each edge.
+
+    Parameters
+    ----------
+    edge_src : Tensor
+        Tensor of shape (N, 3) containing the source of the edge
+    edge_dst : Tensor
+        Tensor of shape (N, 3) containing the destination of the edge
+    axis : int, optional
+        Axis along which the norm is computed, by default 1
+
+    Returns
+    -------
+    np.ndarray
+        Array of shape (N, ) containing the length of each edge
+    """
+    return np.linalg.norm(edge_src - edge_dst, axis=axis)
+
+
+def deg2rad(deg: Tensor) -> Tensor:
+    """Converts degrees to radians
+
+    Parameters
+    ----------
+    deg :
+        Tensor of shape (N, ) containing the degrees
+
+    Returns
+    -------
+    Tensor
+        Tensor of shape (N, ) containing the radians
+    """
+    return deg * np.pi / 180
+
+
+def rad2deg(rad):
+    """Converts radians to degrees
+
+    Parameters
+    ----------
+    rad :
+        Tensor of shape (N, ) containing the radians
+
+    Returns
+    -------
+    Tensor
+        Tensor of shape (N, ) containing the degrees
+    """
+    return rad * 180 / np.pi
+
+
+def azimuthal_angle(lon: Tensor) -> Tensor:
+    """
+    Gives the azimuthal angle of a point on the sphere
+
+    Parameters
+    ----------
+    lon : Tensor
+        Tensor of shape (N, ) containing the longitude of the point
+
+    Returns
+    -------
+    Tensor
+        Tensor of shape (N, ) containing the azimuthal angle
+    """
+    angle = torch.where(lon >= 0.0, 2 * np.pi - lon, -lon)
+    return angle
+
+
+def polar_angle(lat: Tensor) -> Tensor:
+    """
+    Gives the polar angle of a point on the sphere
+
+    Parameters
+    ----------
+    lat : Tensor
+        Tensor of shape (N, ) containing the latitude of the point
+
+    Returns
+    -------
+    Tensor
+        Tensor of shape (N, ) containing the polar angle
+    """
+    angle = torch.where(lat >= 0.0, lat, 2 * np.pi + lat)
+    return angle
+
+
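+# Editor's sanity check (illustrative, not in the original PR): latlon2xyz and
+# xyz2latlon should be inverses away from the poles.
+def _check_latlon_roundtrip() -> None:
+    pts = torch.tensor([[0.0, 0.0], [45.0, 90.0], [-30.0, 150.0]])
+    testing.assert_close(xyz2latlon(latlon2xyz(pts)), pts, atol=1e-4, rtol=1e-4)
+
+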
+def geospatial_rotation(invar: Tensor, theta: Tensor, axis: str, unit: str = "rad") -> Tensor:
+    """Rotation using the right hand rule
+
+    Parameters
+    ----------
+    invar : Tensor
+        Tensor of shape (N, 3) containing x, y, z coordinates
+    theta : Tensor
+        Tensor of shape (N, ) containing the rotation angle
+    axis : str
+        Axis of rotation
+    unit : str, optional
+        Unit of theta, by default "rad"
+
+    Returns
+    -------
+    Tensor
+        Tensor of shape (N, 3) containing the rotated x, y, z coordinates
+    """
+
+    # Get theta into radians; the rotation matrices below assume radians
+    if unit == "deg":
+        theta = deg2rad(theta)
+    elif unit == "rad":
+        pass
+    else:
+        raise ValueError("Not a valid unit")
+
+    invar = torch.unsqueeze(invar, -1)
+    rotation = torch.zeros((theta.size(0), 3, 3))
+    cos = torch.cos(theta)
+    sin = torch.sin(theta)
+
+    if axis == "x":
+        rotation[:, 0, 0] += 1.0
+        rotation[:, 1, 1] += cos
+        rotation[:, 1, 2] -= sin
+        rotation[:, 2, 1] += sin
+        rotation[:, 2, 2] += cos
+    elif axis == "y":
+        rotation[:, 0, 0] += cos
+        rotation[:, 0, 2] += sin
+        rotation[:, 1, 1] += 1.0
+        rotation[:, 2, 0] -= sin
+        rotation[:, 2, 2] += cos
+    elif axis == "z":
+        rotation[:, 0, 0] += cos
+        rotation[:, 0, 1] -= sin
+        rotation[:, 1, 0] += sin
+        rotation[:, 1, 1] += cos
+        rotation[:, 2, 2] += 1.0
+    else:
+        raise ValueError("Invalid axis")
+
+    outvar = torch.matmul(rotation, invar)
+    outvar = outvar.squeeze()
+    return outvar
+
+
+def add_edge_features(graph: Data, pos: Tensor, normalize: bool = True) -> Data:
+    """Adds edge features to the graph.
+
+    Parameters
+    ----------
+    graph : Data
+        The graph to add edge features to.
+    pos : Tensor or (Tensor, Tensor) tuple
+        The node positions, or a (src_pos, dst_pos) tuple for bipartite graphs.
+    normalize : bool, optional
+        Whether to normalize the edge features, by default True
+
+    Returns
+    -------
+    Data
+        The graph with edge features.
+    """
+
+    if isinstance(pos, tuple):
+        src_pos, dst_pos = pos
+    else:
+        src_pos = dst_pos = pos
+    src, dst = graph.edge_index
+
+    src_pos, dst_pos = src_pos[src.long()], dst_pos[dst.long()]
+    dst_latlon = xyz2latlon(dst_pos, unit="rad")
+    dst_lat, dst_lon = dst_latlon[:, 0], dst_latlon[:, 1]
+
+    # azimuthal & polar rotation
+    theta_azimuthal = azimuthal_angle(dst_lon)
+    theta_polar = polar_angle(dst_lat)
+
+    src_pos = geospatial_rotation(src_pos, theta=theta_azimuthal, axis="z", unit="rad")
+    dst_pos = geospatial_rotation(dst_pos, theta=theta_azimuthal, axis="z", unit="rad")
+    # y values should be zero
+    try:
+        testing.assert_close(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1]))
+    except AssertionError:
+        raise ValueError("Invalid projection of edge nodes to local coordinate system")
+    src_pos = geospatial_rotation(src_pos, theta=theta_polar, axis="y", unit="rad")
+    dst_pos = geospatial_rotation(dst_pos, theta=theta_polar, axis="y", unit="rad")
+    # x values should be one, y & z values should be zero
+    try:
+        testing.assert_close(dst_pos[:, 0], torch.ones_like(dst_pos[:, 0]))
+        testing.assert_close(dst_pos[:, 1], torch.zeros_like(dst_pos[:, 1]))
+        testing.assert_close(dst_pos[:, 2], torch.zeros_like(dst_pos[:, 2]))
+    except AssertionError:
+        raise ValueError("Invalid projection of edge nodes to local coordinate system")
+
+    # prepare edge features
+    disp = src_pos - dst_pos
+    disp_norm = torch.linalg.norm(disp, dim=-1, keepdim=True)
+
+    # normalize using the longest edge
+    if normalize:
+        max_disp_norm = torch.max(disp_norm)
+        graph["edge_attr"] = torch.cat((disp / max_disp_norm, disp_norm / max_disp_norm), dim=-1)
+    else:
+        graph["edge_attr"] = torch.cat((disp, disp_norm), dim=-1)
+    return graph
+
+
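+# Editor's note (added for clarity): the resulting graph["edge_attr"] has shape
+# (num_edges, 4): the displacement vector in the rotated local frame (3 values)
+# plus its norm (1 value), both divided by the longest edge when normalize=True.
+
+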
+def add_node_features(graph: Data, pos: Tensor) -> Data:
+    """Adds the cosine of latitude and the sine and cosine of longitude as node features.
+
+    Parameters
+    ----------
+    graph : Data
+        The graph to add node features to.
+    pos : Tensor
+        The node positions.
+
+    Returns
+    -------
+    Data
+        The graph with node features.
+    """
+    latlon = xyz2latlon(pos)
+    lat, lon = latlon[:, 0], latlon[:, 1]
+    graph["x"] = torch.stack((torch.cos(lat), torch.sin(lon), torch.cos(lon)), dim=-1)
+    return graph
+
+
+def generate_grid_to_mesh(
+    lat_lons: torch.Tensor, mesh: Data, max_edge_length: Optional[float] = None
+) -> HeteroData:
+    """Create the grid-to-mesh bipartite graph, connecting lat/lon points to nearby mesh nodes.
+
+    Following GraphCast, a grid point is connected to a mesh node when their distance
+    is at most 0.6 times the longest mesh edge.
+    """
+    if max_edge_length is None:
+        max_edge_len = np.max(
+            get_edge_len(mesh.pos[mesh.edge_index[0]], mesh.pos[mesh.edge_index[1]])
+        )
+    else:
+        max_edge_len = max_edge_length
+
+    # create the grid2mesh bipartite graph
+    cartesian_grid = latlon2xyz(lat_lons)
+    n_nbrs = 4
+    neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos)
+    distances, indices = neighbors.kneighbors(cartesian_grid)
+
+    src, dst = [], []
+    for i in range(len(cartesian_grid)):
+        for j in range(n_nbrs):
+            if distances[i][j] <= 0.6 * max_edge_len:
+                src.append(i)
+                dst.append(indices[i][j])
+    # This is in COO format now, and it is not bidirectional, so no copying
+    grid2mesh = HeteroData()
+    grid2mesh["grid"].pos = torch.tensor(cartesian_grid, dtype=torch.float)
+    grid2mesh["mesh"].pos = mesh.pos
+    grid2mesh["grid", "g2m", "mesh"].edge_index = torch.tensor([src, dst], dtype=torch.long)
+    # Add edge features
+    grid2mesh = add_edge_features(grid2mesh, (grid2mesh["grid"].pos, grid2mesh["mesh"].pos))
+    return grid2mesh
+
+
+def generate_mesh_to_grid(lat_lons: torch.Tensor, mesh: Data) -> HeteroData:
+    """Create the mesh-to-grid bipartite graph, connecting each lat/lon point to the three
+    vertices of its mesh face (currently approximated via the nearest mesh node, see the TODO below)."""
+    # create the mesh2grid bipartite graph
+    cartesian_grid = latlon2xyz(lat_lons)
+    n_nbrs = 1
+    neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos)
+    _, indices = neighbors.kneighbors(cartesian_grid)
+    indices = indices.flatten()
+
+    src = [
+        p for i in indices for p in mesh.pos[i]  # TODO Need to fix this to be the faces in the mesh
+    ]
+    dst = [i for i in range(len(cartesian_grid)) for _ in range(3)]
+
+    mesh2grid = HeteroData()
+    mesh2grid["mesh"].pos = mesh.pos
+    mesh2grid["grid"].pos = torch.tensor(cartesian_grid, dtype=torch.float)
+    mesh2grid["mesh", "m2g", "grid"].edge_index = torch.tensor([src, dst], dtype=torch.long)
+    # Add edge features
+    mesh2grid = add_edge_features(mesh2grid, (mesh2grid["mesh"].pos, mesh2grid["grid"].pos))
+
+    return mesh2grid
+
+
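+# Editor's usage sketch (illustrative only; these helpers are still WIP, see the
+# TODO above): given `lat_lons` as an (N, 2) tensor of degrees and an icosphere
+# `mesh` (a `Data` with `pos` and `edge_index`), the bipartite graphs are built as
+#     g2m = generate_grid_to_mesh(lat_lons, mesh)
+#     m2g = generate_mesh_to_grid(lat_lons, mesh)
+
+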
+ """ + import matplotlib.pyplot as plt + import networkx as nx + from torch_geometric.utils import to_networkx + + G = to_networkx(graph, to_undirected=True) + import networkx as nx + import numpy as np + import matplotlib.pyplot as plt + from mpl_toolkits.mplot3d import Axes3D + + # 3d spring layout + fixed_positions = {v: graph.pos[v] for v in range(len(graph.pos))} + fixed_nodes = fixed_positions.keys() + pos = nx.spring_layout(G, dim=3, fixed=fixed_nodes, pos=fixed_positions) + # Extract node and edge positions from the layout + node_xyz = np.array([pos[v] for v in sorted(G)]) + edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()]) + + # Create the 3D figure + fig = plt.figure() + ax = fig.add_subplot(111, projection="3d") + + # Plot the nodes - alpha is scaled by "depth" automatically + ax.scatter(*node_xyz.T, s=10, ec="w") + + # Plot the edges + for vizedge in edge_xyz: + ax.plot(*vizedge.T, color="tab:gray") + + def _format_axes(ax): + """Visualization options for the 3D axes.""" + # Turn gridlines off + ax.grid(False) + # Suppress tick labels + for dim in (ax.xaxis, ax.yaxis, ax.zaxis): + dim.set_ticks([]) + # Set axes labels + ax.set_xlabel("x") + ax.set_ylabel("y") + ax.set_zlabel("z") + + _format_axes(ax) + fig.tight_layout() + # rotate the axes and update + for angle in range(0, 360): + ax.view_init(30, angle) + plt.draw() + plt.pause(0.01) + plt.show() diff --git a/graph_weather/models/layers/encoder.py b/graph_weather/models/layers/encoder.py index 916a6523..f4c82f80 100644 --- a/graph_weather/models/layers/encoder.py +++ b/graph_weather/models/layers/encoder.py @@ -48,6 +48,8 @@ def __init__( hidden_layers_processor_edge=2, mlp_norm_type="LayerNorm", use_checkpointing: bool = False, + input_graph: Data = None, + latent_graph: Data = None, ): """ Encode the lat/lon data inot the isohedron graph @@ -65,6 +67,8 @@ def __init__( mlp_norm_type: Type of norm for the MLPs one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None use_checkpointing: Whether to use gradient checkpointing to use less memory + input_graph: Input graph to use, if None, will generate a default one + latent_graph: Latent graph to use, if None, will generate a default one """ super().__init__() self.use_checkpointing = use_checkpointing diff --git a/graph_weather/models/losses.py b/graph_weather/models/losses.py index 5e5d27ec..ab1a5e68 100644 --- a/graph_weather/models/losses.py +++ b/graph_weather/models/losses.py @@ -61,3 +61,113 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor): out = out * self.weights.expand_as(out) assert not torch.isnan(out).any() return out.mean() + + +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.nn as nn +from torch.autograd.function import once_differentiable + + +class CellAreaWeightedLossFunction(nn.Module): + """Loss function with cell area weighting. 
+class CellAreaWeightedLossFunction(nn.Module):
+    """Loss function with cell area weighting.
+
+    Parameters
+    ----------
+    area : torch.Tensor
+        Cell area with shape [H, W].
+    """
+
+    def __init__(self, area):
+        super().__init__()
+        self.area = area
+
+    def forward(self, invar, outvar):
+        """
+        Implicit forward function which computes the loss given
+        a prediction and the corresponding targets.
+
+        Parameters
+        ----------
+        invar : torch.Tensor
+            prediction of shape [T, C, H, W].
+        outvar : torch.Tensor
+            target values of shape [T, C, H, W].
+        """
+
+        loss = (invar - outvar) ** 2
+        loss = loss.mean(dim=(0, 1))
+        loss = torch.mul(loss, self.area)
+        loss = loss.mean()
+        return loss
+
+
+class CustomCellAreaWeightedLossAutogradFunction(torch.autograd.Function):
+    """Autograd function for custom loss with cell area weighting."""
+
+    @staticmethod
+    def forward(ctx, invar: torch.Tensor, outvar: torch.Tensor, area: torch.Tensor):
+        """Forward of custom loss function with cell area weighting."""
+
+        diff = invar - outvar  # T x C x H x W
+        loss = diff**2
+        loss = loss.mean(dim=(0, 1))
+        loss = torch.mul(loss, area)
+        loss = loss.mean()
+        loss_grad = diff * (2.0 / (math.prod(invar.shape)))
+        loss_grad *= area.unsqueeze(0).unsqueeze(0)
+        ctx.save_for_backward(loss_grad)
+        return loss
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_loss: torch.Tensor):
+        """Backward method of custom loss function with cell area weighting."""
+
+        # grad_loss should be 1, multiply nevertheless
+        # to avoid issues with cases where this isn't the case
+        (grad_invar,) = ctx.saved_tensors
+        return grad_invar * grad_loss, None, None
+
+
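+# Editor's note (added for clarity): the precomputed gradient above follows from
+#     loss = sum(area * diff**2) / (T*C*H*W),
+# so d(loss)/d(invar) = 2 * area * diff / (T*C*H*W), which is exactly the
+# `loss_grad` saved in forward; backward only scales it by the incoming gradient.
+
+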
+ """ + + return CustomCellAreaWeightedLossAutogradFunction.apply(invar, outvar, self.area) diff --git a/train/deepspeed_graph.py b/train/deepspeed_graph.py deleted file mode 100644 index 23ede888..00000000 --- a/train/deepspeed_graph.py +++ /dev/null @@ -1,64 +0,0 @@ -import pytorch_lightning as pl -import torch -from pytorch_lightning import Trainer - -from graph_weather import GraphWeatherForecaster - -lat_lons = [] -for lat in range(-90, 90, 1): - for lon in range(0, 360, 1): - lat_lons.append((lat, lon)) - - -class LitModel(pl.LightningModule): - def __init__(self, lat_lons, feature_dim, aux_dim): - super().__init__() - self.model = GraphWeatherForecaster( - lat_lons=lat_lons, feature_dim=feature_dim, aux_dim=aux_dim - ) - - def training_step(self, batch): - x, y = batch - x = x.half() - y = y.half() - out = self.forward(x) - criterion = torch.nn.MSELoss() - loss = criterion(out, y) - return loss - - def configure_optimizers(self): - return torch.optim.AdamW(self.parameters()) - - def forward(self, x): - return self.model(x) - - -# Fake data -from torch.utils.data import DataLoader, Dataset - - -class FakeDataset(Dataset): - def __init__(self): - super(FakeDataset, self).__init__() - - def __len__(self): - return 64000 - - def __getitem__(self, item): - return torch.randn((64800, 605 + 32)), torch.randn((64800, 605)) - - -model = LitModel(lat_lons=lat_lons, feature_dim=605, aux_dim=32) -trainer = Trainer( - accelerator="gpu", - devices=1, - strategy="deepspeed_stage_3_offload", - precision=16, - max_epochs=10, - limit_train_batches=2000, -) -dataset = FakeDataset() -train_dataloader = DataLoader( - dataset, batch_size=1, num_workers=1, pin_memory=True, prefetch_factor=1 -) -trainer.fit(model=model, train_dataloaders=train_dataloader) diff --git a/train/pl_graph_weather.py b/train/pl_graph_weather.py deleted file mode 100644 index 40ed7f21..00000000 --- a/train/pl_graph_weather.py +++ /dev/null @@ -1,355 +0,0 @@ -"""PyTorch Lightning training script for the weather forecasting model""" -import click -import datasets -import numpy as np -import pandas as pd -import pytorch_lightning as pl -import torch -from pysolar.util import extraterrestrial_irrad -from pytorch_lightning.callbacks import ModelCheckpoint -from torch.utils.data import DataLoader - -from graph_weather import GraphWeatherForecaster -from graph_weather.data import const -from graph_weather.models.losses import NormalizedMSELoss - -const.FORECAST_MEANS = {var: np.asarray(value) for var, value in const.FORECAST_MEANS.items()} -const.FORECAST_STD = {var: np.asarray(value) for var, value in const.FORECAST_STD.items()} - - -def worker_init_fn(worker_id): - np.random.seed(np.random.get_state()[1][0] + worker_id) - - -def get_mean_stds(): - names = [ - "CLMR", - "GRLE", - "VVEL", - "VGRD", - "UGRD", - "O3MR", - "CAPE", - "TMP", - "PLPL", - "DZDT", - "CIN", - "HGT", - "RH", - "ICMR", - "SNMR", - "SPFH", - "RWMR", - "TCDC", - "ABSV", - ] - means = {} - stds = {} - # For pressure level values - for n in names: - if ( - len( - sorted( - [ - float(var.split(".", 1)[-1].split("_")[0]) - for var in const.FORECAST_MEANS - if "mb" in var and n in var and "-" not in var - ] - ) - ) - > 0 - ): - means[n + "_mb"] = [] - stds[n + "_mb"] = [] - for value in sorted( - [ - float(var.split(".", 1)[-1].split("_")[0]) - for var in const.FORECAST_MEANS - if "mb" in var and n in var and "-" not in var - ] - ): - # Is floats now, but will be fixed - if value >= 1: - value = int(value) - var_name = f"{n}.{value}_mb" - # print(var_name) - - means[n + 
"_mb"].append(const.FORECAST_MEANS[var_name]) - stds[n + "_mb"].append(const.FORECAST_STD[var_name]) - means[n + "_mb"] = np.mean(np.stack(means[n + "_mb"], axis=-1)) - stds[n + "_mb"] = np.mean(np.stack(stds[n + "_mb"], axis=-1)) - - # For surface values - for n in list( - set( - [ - var.split(".", 1)[0] - for var in const.FORECAST_MEANS - if "surface" in var - and "level" not in var - and "2e06" not in var - and "below" not in var - and "atmos" not in var - and "tropo" not in var - and "iso" not in var - and "planetary_boundary_layer" not in var - ] - ) - ): - means[n] = const.FORECAST_MEANS[n + ".surface"] - stds[n] = const.FORECAST_STD[n + ".surface"] - - # For Cloud levels - for n in list( - set( - [ - var.split(".", 1)[0] - for var in const.FORECAST_MEANS - if "sigma" not in var - and "level" not in var - and "2e06" not in var - and "below" not in var - and "atmos" not in var - and "tropo" not in var - and "iso" not in var - and "planetary_boundary_layer" not in var - ] - ) - ): - if "LCDC" in n: # or "MCDC" in n or "HCDC" in n: - means[n] = const.FORECAST_MEANS["LCDC.low_cloud_layer"] - stds[n] = const.FORECAST_STD["LCDC.low_cloud_layer"] - if "MCDC" in n: # or "HCDC" in n: - means[n] = const.FORECAST_MEANS["MCDC.middle_cloud_layer"] - stds[n] = const.FORECAST_STD["MCDC.middle_cloud_layer"] - if "HCDC" in n: - means[n] = const.FORECAST_MEANS["HCDC.high_cloud_layer"] - stds[n] = const.FORECAST_STD["HCDC.high_cloud_layer"] - - # Now for each of these - means["max_wind"] = [] - stds["max_wind"] = [] - for n in sorted([var for var in const.FORECAST_MEANS if "max_wind" in var]): - means["max_wind"].append(const.FORECAST_MEANS[n]) - stds["max_wind"].append(const.FORECAST_STD[n]) - means["max_wind"] = np.stack(means["max_wind"], axis=-1) - stds["max_wind"] = np.stack(stds["max_wind"], axis=-1) - - for i in [2, 10, 20, 30, 40, 50, 80, 100]: - means[f"{i}m_above_ground"] = [] - stds[f"{i}m_above_ground"] = [] - for n in sorted([var for var in const.FORECAST_MEANS if f"{i}_m_above_ground" in var]): - means[f"{i}m_above_ground"].append(const.FORECAST_MEANS[n]) - stds[f"{i}m_above_ground"].append(const.FORECAST_STD[n]) - means[f"{i}m_above_ground"] = np.stack(means[f"{i}m_above_ground"], axis=-1) - stds[f"{i}m_above_ground"] = np.stack(stds[f"{i}m_above_ground"], axis=-1) - return means, stds - - -means, stds = get_mean_stds() - - -def process_data(data): - data.update( - { - key: np.expand_dims(np.asarray(value), axis=-1) - for key, value in data.items() - if key.replace("current_", "").replace("next_", "") in means.keys() - and np.asarray(value).ndim == 2 - } - ) # Add third dimension for ones with 2 - input_data = { - key.replace("current_", ""): torch.from_numpy( - (value - means[key.replace("current_", "")]) / stds[key.replace("current_", "")] - ) - for key, value in data.items() - if "current" in key and "time" not in key - } - output_data = { - key.replace("next_", ""): torch.from_numpy( - (value - means[key.replace("next_", "")]) / stds[key.replace("next_", "")] - ) - for key, value in data.items() - if "next" in key and "time" not in key - } - lat_lons = np.array( - np.meshgrid(np.asarray(data["latitude"]).flatten(), np.asarray(data["longitude"]).flatten()) - ).T.reshape((-1, 2)) - sin_lat_lons = np.sin(lat_lons * np.pi / 180.0) - cos_lat_lons = np.cos(lat_lons * np.pi / 180.0) - date = pd.to_datetime(data["timestamps"][0], utc=True) - solar_times = [ - np.array( - [ - extraterrestrial_irrad( - when=date.to_pydatetime(), latitude_deg=lat, longitude_deg=lon - ) - for lat, lon in 
-                for lat, lon in lat_lons
-            ]
-        )
-    ]
-    for when in pd.date_range(
-        date - pd.Timedelta("12 hours"), date + pd.Timedelta("12 hours"), freq="1H"
-    ):
-        solar_times.append(
-            np.array(
-                [
-                    extraterrestrial_irrad(
-                        when=when.to_pydatetime(), latitude_deg=lat, longitude_deg=lon
-                    )
-                    for lat, lon in lat_lons
-                ]
-            )
-        )
-    solar_times = np.array(solar_times)
-    # Normalize to between -1 and 1
-    solar_times -= const.SOLAR_MEAN
-    solar_times /= const.SOLAR_STD
-    input_data = torch.concat([value for _, value in input_data.items()], dim=-1)
-    output_data = torch.concat([value for _, value in output_data.items()], dim=-1)
-    input_data = input_data.transpose(0, 1).reshape(-1, input_data.shape[-1])
-    output_data = output_data.transpose(0, 1).reshape(-1, input_data.shape[-1])
-    day_of_year = pd.to_datetime(data["timestamps"][0], utc=True).dayofyear / 366.0
-    sin_of_year = np.ones_like(lat_lons)[:, 0] * np.sin(day_of_year)
-    cos_of_year = np.ones_like(lat_lons)[:, 0] * np.cos(day_of_year)
-    to_concat = [
-        input_data,
-        torch.permute(torch.from_numpy(solar_times), (1, 0)),
-        torch.from_numpy(sin_lat_lons),
-        torch.from_numpy(cos_lat_lons),
-        torch.from_numpy(np.expand_dims(sin_of_year, axis=-1)),
-        torch.from_numpy(np.expand_dims(cos_of_year, axis=-1)),
-    ]  # , landsea_fixed]
-    input_data = torch.concat(to_concat, dim=-1)
-    new_data = {
-        "input": input_data.float().numpy(),
-        "output": output_data.float().numpy(),
-        "has_nans": not np.isnan(input_data.float().numpy()).any()
-        and not np.isnan(output_data.float().numpy()).any(),
-    }
-    return new_data
-
-
-class GraphDataModule(pl.LightningDataModule):
-    def __init__(self, deg: str = "2.0", batch_size: int = 1):
-        super().__init__()
-        self.batch_size = batch_size
-        self.dataset = datasets.load_dataset(
-            "openclimatefix/gfs-surface-pressure-2deg", split="train+validation", streaming=False
-        )
-        features = datasets.Features(
-            {
-                "input": datasets.Array2D(shape=(16380, 637), dtype="float32"),
-                "output": datasets.Array2D(shape=(16380, 605), dtype="float32"),
-                "has_nans": datasets.Value("bool"),
-            }
-        )
-        self.dataset = (
-            self.dataset.map(
-                process_data,
-                remove_columns=self.dataset.column_names,
-                features=features,
-                num_proc=16,
-                writer_batch_size=2,
-            )
-            .filter(lambda x: x["has_nans"])
-            .with_format("torch")
-        )
-
-    def train_dataloader(self):
-        return DataLoader(self.dataset, batch_size=self.batch_size, num_workers=2)
-
-
-class LitGraphForecaster(pl.LightningModule):
-    def __init__(
-        self,
-        lat_lons: list,
-        feature_dim: int = 605,
-        aux_dim: int = 32,
-        hidden_dim: int = 64,
-        num_blocks: int = 3,
-        lr: float = 3e-4,
-    ):
-        super().__init__()
-        self.model = GraphWeatherForecaster(
-            lat_lons,
-            feature_dim=feature_dim,
-            aux_dim=aux_dim,
-            hidden_dim_decoder=hidden_dim,
-            hidden_dim_processor_node=hidden_dim,
-            hidden_dim_processor_edge=hidden_dim,
-            num_blocks=num_blocks,
-        )
-        self.criterion = NormalizedMSELoss(
-            lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
-        )
-        self.lr = lr
-        self.save_hyperparameters()
-
-    def forward(self, x):
-        return self.model(x)
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch["input"], batch["output"]
-        if torch.isnan(x).any() or torch.isnan(y).any():
-            return None
-        y_hat = self.forward(x)
-        loss = self.criterion(y_hat, y)
-        return loss
-
-    def configure_optimizers(self):
-        return torch.optim.AdamW(self.parameters(), lr=self.lr)
-
-
-@click.command()
-@click.option(
-    "--num-blocks",
-    default=5,
-    help="Where to save the zarr files",
-    type=click.INT,
-)
-@click.option(
-    "--hidden",
-    default=32,
-    help="Where to save the zarr files",
-    type=click.INT,
-)
-@click.option(
-    "--batch",
-    default=1,
-    help="Where to save the zarr files",
-    type=click.INT,
-)
-@click.option(
-    "--gpus",
-    default=1,
-    help="Where to save the zarr files",
-    type=click.INT,
-)
-def run(num_blocks, hidden, batch, gpus):
-    hf_ds = datasets.load_dataset(
-        "openclimatefix/gfs-surface-pressure-2deg", split="train", streaming=False
-    )
-    example_batch = next(iter(hf_ds))
-    lat_lons = np.array(
-        np.meshgrid(
-            np.asarray(example_batch["latitude"]).flatten(),
-            np.asarray(example_batch["longitude"]).flatten(),
-        )
-    ).T.reshape((-1, 2))
-    checkpoint_callback = ModelCheckpoint(dirpath="./", save_top_k=2, monitor="loss")
-    dset = GraphDataModule(batch_size=batch)
-    model = LitGraphForecaster(lat_lons=lat_lons, num_blocks=num_blocks, hidden_dim=hidden)
-    trainer = pl.Trainer(
-        accelerator="gpu",
-        devices=gpus,
-        max_epochs=100,
-        precision=16,
-        callbacks=[checkpoint_callback],
-    )
-    # strategy="deepspeed_stage_2_offload")
-    trainer.fit(model, dset)
-
-
-if __name__ == "__main__":
-    run()
diff --git a/train/run_fulll.py b/train/run_fulll.py
deleted file mode 100644
index fdf65bc5..00000000
--- a/train/run_fulll.py
+++ /dev/null
@@ -1,133 +0,0 @@
-"""Training script for training the weather forecasting model"""
-import json
-
-import numpy as np
-import torch
-import torch.optim as optim
-import torchvision.transforms as transforms
-import xarray as xr
-from torch.utils.data import DataLoader, Dataset
-
-from graph_weather import GraphWeatherForecaster
-from graph_weather.data import const
-from graph_weather.models.losses import NormalizedMSELoss
-
-
-class XrDataset(Dataset):
-    def __init__(self):
-        super().__init__()
-        with open("hf_forecasts.json", "r") as f:
-            files = json.load(f)
-        self.filepaths = [
-            "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/"
-            + f
-            for f in files
-        ]
-        self.data = xr.open_mfdataset(
-            self.filepaths, engine="zarr", concat_dim="reftime", combine="nested"
-        ).sortby("reftime")
-
-    def __len__(self):
-        return len(self.filepaths)
-
-    def __getitem__(self, item):
-        start_idx = np.random.randint(0, 15)
-        data = self.data.isel(reftime=item, time=slice(start_idx, start_idx + 1))
-
-        start = data.isel(time=0)
-        end = data.isel(time=1)
-        # Stack the data into a large data cube
-        input_data = np.stack(
-            [
-                (start[f"{var}"].values - const.FORECAST_MEANS[f"{var}"])
-                / (const.FORECAST_STD[f"{var}"] + 0.0001)
-                for var in start.data_vars
-                if "mb" in var or "surface" in var
-            ],
-            axis=-1,
-        )
-        input_data = np.nan_to_num(input_data)
-        assert not np.isnan(input_data).any()
-        output_data = np.stack(
-            [
-                (end[f"{var}"].values - const.FORECAST_MEANS[f"{var}"])
-                / (const.FORECAST_STD[f"{var}"] + 0.0001)
-                for var in end.data_vars
-                if "mb" in var or "surface" in var
-            ],
-            axis=-1,
-        )
-        output_data = np.nan_to_num(output_data)
-        assert not np.isnan(output_data).any()
-        transform = transforms.Compose([transforms.ToTensor()])
-        # Normalize now
-        return (
-            transform(input_data).transpose(0, 1).reshape(-1, input_data.shape[-1]),
-            transform(output_data).transpose(0, 1).reshape(-1, input_data.shape[-1]),
-        )
-
-
-with open("hf_forecasts.json", "r") as f:
-    files = json.load(f)
-files = [
-    "zip:///::https://huggingface.co/datasets/openclimatefix/gfs-reforecast/resolve/main/" + f
-    for f in files
-]
-data = (
-    xr.open_zarr(files[0], consolidated=True).isel(time=0)
-    # .coarsen(latitude=8, boundary="pad")
-    # .mean()
-    # .coarsen(longitude=8)
-    # .mean()
-)
-print(data)
-# print("Done coarsening")
-lat_lons = np.array(np.meshgrid(data.latitude.values, data.longitude.values)).T.reshape(-1, 2)
-device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
-# Get the variance of the variables
-feature_variances = []
-for var in data.data_vars:
-    if "mb" in var or "surface" in var:
-        feature_variances.append(const.FORECAST_DIFF_STD[var] ** 2)
-criterion = NormalizedMSELoss(
-    lat_lons=lat_lons, feature_variance=feature_variances, device=device
-).to(device)
-means = []
-dataset = DataLoader(XrDataset(), batch_size=1, num_workers=32)
-model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device)
-optimizer = optim.AdamW(model.parameters(), lr=0.001)
-print("Done Setup")
-import time
-
-for epoch in range(100):  # loop over the dataset multiple times
-    running_loss = 0.0
-    start = time.time()
-    print(f"Start Epoch: {epoch}")
-    for i, data in enumerate(dataset):
-        # get the inputs; data is a list of [inputs, labels]
-        inputs, labels = data[0].to(device), data[1].to(device)
-        # zero the parameter gradients
-        optimizer.zero_grad()
-
-        # forward + backward + optimize
-        outputs = model(inputs)
-
-        loss = criterion(outputs, labels)
-        loss.backward()
-        optimizer.step()
-
-        # print statistics
-        running_loss += loss.item()
-        end = time.time()
-        print(
-            f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / (i + 1):.3f} Time: {end - start} sec"
-        )
-    if epoch % 5 == 0:
-        assert not np.isnan(running_loss)
-        model.push_to_hub(
-            "graph-weather-forecaster-2.0deg",
-            organization="openclimatefix",
-            commit_message=f"Add model Epoch={epoch}",
-        )
-
-print("Finished Training")