Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 7, 2023
1 parent f0f26d8 commit d6e216e
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions graph_weather/models/graphs/ico.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def inside_points(vAB, vAC):
j = i + 1
interp_multipliers = (np.arange(1, j) / j)[:, None]
out.append(
np.multiply(interp_multipliers, vAC[i, None]) +
np.multiply(1 - interp_multipliers, vAB[i, None])
np.multiply(interp_multipliers, vAC[i, None])
+ np.multiply(1 - interp_multipliers, vAB[i, None])
)
return np.concatenate(out)

Expand All @@ -298,6 +298,7 @@ def generate_icosphere_graph(resolution=1):
edges = np.unique(np.sort(edges, axis=1), axis=0)
return vertices, edges


def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) -> Data:
"""
Generate mapping from lat/lon to icosphere index.
Expand Down Expand Up @@ -325,13 +326,18 @@ def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16, 32, 64)) -
# Check the verticies of each pair are the same up to the resolution
for i in range(len(verticies_per_level) - 1):
for vertex_lower_index, vertex in enumerate(verticies_per_level[i]):
vertex_mapping = np.argmin(np.sum(np.abs(verticies_per_level[i + 1] - vertex), axis=1), axis=0)
vertex_mapping = np.argmin(
np.sum(np.abs(verticies_per_level[i + 1] - vertex), axis=1), axis=0
)
# Change all edge indicies from vertex_lower_index to vertex_mapping
edges_per_level[i + 1][edges_per_level[i + 1] == vertex_lower_index] = vertex_mapping
verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above
verticies = verticies_per_level[-1] # The last layer has all the verticies of the ones above
edges = np.unique(np.sort(np.concatenate(edges_per_level), axis=1), axis=0)
# TODO Create mapping from the lat/lon to the icosphere nodes
ico_graph = Data(x=torch.tensor(verticies, dtype=torch.float), edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous())
ico_graph = Data(
x=torch.tensor(verticies, dtype=torch.float),
edge_index=torch.tensor(edges, dtype=torch.long).t().contiguous(),
)
return ico_graph


Expand Down

0 comments on commit d6e216e

Please sign in to comment.