diff --git a/graph_weather/models/graphs/ico.py b/graph_weather/models/graphs/ico.py index 7ea449ad..d3a386b3 100644 --- a/graph_weather/models/graphs/ico.py +++ b/graph_weather/models/graphs/ico.py @@ -283,8 +283,8 @@ def inside_points(vAB, vAC): j = i + 1 interp_multipliers = (np.arange(1, j) / j)[:, None] out.append( - np.multiply(interp_multipliers, vAC[i, None]) + - np.multiply(1 - interp_multipliers, vAB[i, None]) + np.multiply(interp_multipliers, vAC[i, None]) + + np.multiply(1 - interp_multipliers, vAB[i, None]) ) return np.concatenate(out) @@ -298,6 +298,7 @@ def generate_icosphere_graph(resolution=1): edges = np.unique(np.sort(edges, axis=1), axis=0) return vertices, edges + def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) -> Data: """ Generate mapping from lat/lon to icosphere index. @@ -325,13 +326,18 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) - # Check the verticies of each pair are the same up to the resolution for i in range(len(verticies_per_level) - 1): for vertex_lower_index, vertex in enumerate(verticies_per_level[i]): - vertex_mapping = np.argmin(np.sum(np.abs(verticies_per_level[i + 1] - vertex), axis=1), axis=0) + vertex_mapping = np.argmin( + np.sum(np.abs(verticies_per_level[i + 1] - vertex), axis=1), axis=0 + ) # Change all edge indicies from vertex_lower_index to vertex_mapping edges_per_level[i + 1][edges_per_level[i + 1] == vertex_lower_index] = vertex_mapping - verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above + verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above edges = np.unique(np.sort(np.concatenate(edges_per_level), axis=1), axis=0) # TODO Create mapping from the lat/lon to the icosphere nodes - ico_graph = Data(x=torch.tensor(verticies, dtype=torch.float), edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous()) + ico_graph = Data( + x=torch.tensor(verticies, dtype=torch.float), + edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(), + ) return ico_graph