Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Dec 9, 2023
1 parent 86480bd commit ed9a087
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 3 additions & 1 deletion graph_weather/models/graphs/ico.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,9 @@ def generate_icosphere_graph(resolution=1):
return vertices, edges


def generate_icosphere_mapping(lat_lons, resolutions=(1, 2, 4, 8, 16), bidirectional=True) -> Tuple[HeteroData, Data, HeteroData]:
def generate_icosphere_mapping(
lat_lons, resolutions=(1, 2, 4, 8, 16), bidirectional=True
) -> Tuple[HeteroData, Data, HeteroData]:
"""
Generate mapping from lat/lon to icosphere index.
Expand Down
12 changes: 5 additions & 7 deletions graph_weather/models/graphs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ def add_node_features(graph: Data, pos: Tensor) -> Data:
return graph


def generate_grid_to_mesh(lat_lons: torch.Tensor, mesh: Data, max_edge_length: Optional[float] = None) -> HeteroData:
def generate_grid_to_mesh(
lat_lons: torch.Tensor, mesh: Data, max_edge_length: Optional[float] = None
) -> HeteroData:
if max_edge_length is None:
max_edge_len = np.max(
get_edge_len(mesh.pos[mesh.edge_index[:, 0]], mesh.pos[mesh.edge_index[:, 1]])
Expand Down Expand Up @@ -359,16 +361,12 @@ def generate_mesh_to_grid(lat_lons: torch.Tensor, mesh: Data):
# create the mesh2grid bipartite graph
cartesian_grid = latlon2xyz(lat_lons)
n_nbrs = 1
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(
mesh.pos
)
neighbors = NearestNeighbors(n_neighbors=n_nbrs).fit(mesh.pos)
_, indices = neighbors.kneighbors(cartesian_grid)
indices = indices.flatten()

src = [
p
for i in indices
for p in mesh.pos[i] # TODO Need to fix this to be the faces in the mesh
p for i in indices for p in mesh.pos[i] # TODO Need to fix this to be the faces in the mesh
]
dst = [i for i in range(len(cartesian_grid)) for _ in range(3)]

Expand Down

0 comments on commit ed9a087

Please sign in to comment.