Skip to content

Commit

Permalink
Move plotting to utils
Browse files Browse the repository at this point in the history
Still have an issue with the edges being redundant
  • Loading branch information
jacobbieker committed Dec 8, 2023
1 parent 7aa7b41 commit f2eb1b3
Show file tree
Hide file tree
Showing 2 changed files with 466 additions and 13 deletions.
92 changes: 79 additions & 13 deletions graph_weather/models/graphs/ico.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
import numpy as np
import torch
from torch_geometric.data import Data
from graph_weather.models.graphs.utils import deg2rad, latlon2xyz, xyz2latlon, get_edge_len
from sklearn.neighbors import NearestNeighbors



def icosphere(nu=1, nr_verts=None):
Expand Down Expand Up @@ -299,7 +302,7 @@ def generate_icosphere_graph(resolution=1):
return vertices, edges


def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) -> Data:
def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirectional=True) -> Data:
"""
Generate mapping from lat/lon to icosphere index.
Expand All @@ -315,6 +318,11 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) -
lat_lons: List of (lat,lon) points
resolutions: Icosphere resolution levels, first 7 levels correspond to Graphcast levels, in ascending order of resolution
"""
import numpy as np
import torch
from torch_geometric.data import Data
from graph_weather.models.graphs.utils import deg2rad, latlon2xyz, xyz2latlon, get_edge_len
from sklearn.neighbors import NearestNeighbors
num_latlons = len(lat_lons)
verticies_per_level = []
edges_per_level = []
Expand All @@ -324,22 +332,80 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) -
edges_per_level.append(edges)

# Check the verticies of each pair are the same up to the resolution
for i in range(len(verticies_per_level) - 1):
for i in range(len(verticies_per_level) - 2):
#print(edges_per_level[i])
for vertex_lower_index, vertex in enumerate(verticies_per_level[i]):
vertex_mapping = np.argmin(
np.sum(np.abs(verticies_per_level[i + 1] - vertex), axis=1), axis=0
)
# Change all edge indicies from vertex_lower_index to vertex_mapping
edges_per_level[i + 1][edges_per_level[i + 1] == vertex_lower_index] = vertex_mapping
verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above
edges = np.unique(np.sort(np.concatenate(edges_per_level), axis=1), axis=0)
# Go through each index in the current level, finding the closest vertex in the next level
# Should check all verticies in the next level, and find the closest one
for vertex_upper_index, upper_vertex in enumerate(verticies_per_level[-1]):
multiple_equals = 0
if np.all(vertex == upper_vertex):
multiple_equals += 1
# Manually go through and update the edge tuples to have the vertex_upper_index instead of vertex_lower_index
for edge_index, edge in enumerate(edges_per_level[i]):
if edge[0] == vertex_lower_index:
edges_per_level[i][edge_index][0] = vertex_upper_index
if edge[1] == vertex_lower_index:
edges_per_level[i][edge_index][1] = vertex_upper_index
#print(edges_per_level[i][edges_per_level[i] == vertex_lower_index])
# The vertex is the same, so the edges in the current level that equal vertex_lower_index should be changed to equal vertex_upper_index
#edges_per_level[i][edges_per_level[i] == vertex_upper_index] = vertex_upper_index
if multiple_equals > 1:
print(f"Multiple equals: {multiple_equals}")
#print(edges_per_level[i])
#print("------------------")
verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above
edges = np.sort(np.concatenate(edges_per_level), axis=1)
print(f"Number of edges: {len(edges)}")
if bidirectional:
# Need to add the flipped version of the elements in edges to edges to have the bidirectional graph
edges = np.concatenate([edges, edges[:, [1, 0]]], axis=0)
# TODO Create mapping from the lat/lon to the icosphere nodes
ico_graph = Data(
pos=torch.tensor(verticies, dtype=torch.float),
edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(),
)
print(f"Now Number of edges: {len(edges)}")
print(f"Number of unique edges: {len(np.unique(edges, axis=0))}")
print(f"Max number of repeated edges: {np.max(np.unique(edges, axis=0, return_counts=True)[1])}")
u, c = np.unique(edges, axis=0, return_counts=True)
print(f"First 20 duplicates: {u[c > 1][:20]}")
xyz_grid = latlon2xyz(torch.tensor(lat_lons, dtype=torch.float))
# Find the closest vertex to each point
vertex_mapping = np.argmin(np.sum(np.abs(verticies - xyz_grid[:, None]), axis=2), axis=1)
# Features will need to be in the same lat/lon order as given, and added to the verticies
ico_graph = Data(pos=torch.tensor(verticies, dtype=torch.float),
edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous())
max_edge_len = np.max(get_edge_len(ico_graph.pos[edges_per_level[-1][:, 0]], ico_graph.pos[edges_per_level[-1][:, 1]]))
# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 4
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(ico_graph.pos)
distances, indices = neighbors.kneighbors(cartesian_grid)

src, dst = [], []
for i in range(len(cartesian_grid)):
for j in range(n_nbrs):
if distances[i][j] <= 0.6 * max_edge_len:
src.append(i)
dst.append(indices[i][j])
# Check that the graph is valid
ico_graph.validate(raise_on_error=True)
return ico_graph

generate_icosphere_mapping([(0,0), (0,1), (1,0), (1,1)])
def get_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data):

max_edge_len = np.max(get_edge_len(mesh.pos[mesh.edge_index[:,0]], mesh.pos[mesh.edge_index[:,1]]))

# create the grid2mesh bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 4
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos)
distances, indices = neighbors.kneighbors(cartesian_grid)

src, dst = [], []
for i in range(len(cartesian_grid)):
for j in range(n_nbrs):
if distances[i][j] <= 0.6 * max_edge_len:
src.append(i)
dst.append(indices[i][j])

def generate_latent_ico_graph(h3_mapping, h3_distances):
"""
Expand Down
Loading

0 comments on commit f2eb1b3

Please sign in to comment.