Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple graph generation from model running #76

Draft
wants to merge 26 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a0095e8
Start on adding multiple graph generations
jacobbieker Nov 22, 2023
7c1141c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2023
71b481d
Update workflow versions
jacobbieker Nov 22, 2023
7b19195
Merge remote-tracking branch 'origin/jacob/graph-making' into jacob/g…
jacobbieker Nov 22, 2023
c7b85ae
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2023
0f98e3c
Start on splitting out graph from Encoder
jacobbieker Nov 28, 2023
96272c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 28, 2023
23e718a
Add more notes on GraphCast
jacobbieker Dec 6, 2023
1f3ffda
Add TODOs
jacobbieker Dec 6, 2023
03d877a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2023
bab39e0
Add fix for nu=2
jacobbieker Dec 6, 2023
0349106
Merge remote-tracking branch 'origin/jacob/graph-making' into jacob/g…
jacobbieker Dec 6, 2023
29d3986
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 6, 2023
943cdff
Add correctly generating icosphere graph
jacobbieker Dec 7, 2023
f0f26d8
Merge remote-tracking branch 'origin/jacob/graph-making' into jacob/g…
jacobbieker Dec 7, 2023
d6e216e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2023
2a5299e
Change to position for verticies
jacobbieker Dec 7, 2023
e6b94f6
Merge remote-tracking branch 'origin/jacob/graph-making' into jacob/g…
jacobbieker Dec 7, 2023
7aa7b41
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2023
f2eb1b3
Move plotting to utils
jacobbieker Dec 8, 2023
9cdb87f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 8, 2023
d705712
Refactor grid to mesh and mesh to grid
jacobbieker Dec 9, 2023
322567b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2023
a0c28f9
Update graph generation
jacobbieker Dec 9, 2023
86480bd
Merge remote-tracking branch 'origin/jacob/graph-making' into jacob/g…
jacobbieker Dec 9, 2023
ed9a087
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 9, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/workflows.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ jobs:
fail-fast: true
matrix:
os: [ubuntu-latest]
python-version: [3.8, 3.9, "3.10"]
torch-version: [1.13.0, 2.0.0]
python-version: ["3.10", "3.11"]
torch-version: [2.0.0, 2.1.0]
include:
- torch-version: 1.13.0
torchvision-version: 0.14.0
- torch-version: 2.0.0
torchvision-version: 0.15.1
- torch-version: 2.1.0
torchvision-version: 0.16.0

steps:
- uses: actions/checkout@v2
Expand Down
Empty file added graph_weather/data/utils.py
Empty file.
1 change: 1 addition & 0 deletions graph_weather/models/graphs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Set of graph classes for generating different meshes"""
76 changes: 76 additions & 0 deletions graph_weather/models/graphs/hexagonal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Generate hexagonal global grid using Uber's H3 library."""
import h3
import numpy as np
import torch
from torch_geometric.data import Data


def generate_hexagonal_grid(resolution: int = 2) -> np.ndarray:
"""Generate hexagonal global grid using Uber's H3 library.

Args:
resolution: H3 resolution level

Returns:
Hexagonal grid
"""
base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)}
return np.array(base_h3_grid), base_h3_map


def generate_h3_mapping(lat_lons: list, resolution: int = 2) -> dict:
"""Generate mapping from lat/lon to h3 index.

Args:
lat_lons: List of (lat,lon) points
resolution: H3 resolution level
"""
num_latlons = len(lat_lons)
base_h3_grid = sorted(list(h3.uncompact(h3.get_res0_indexes(), resolution)))
base_h3_map = {h_i: i for i, h_i in enumerate(base_h3_grid)}
h3_grid = [h3.geo_to_h3(lat, lon, resolution) for lat, lon in lat_lons]
h3_mapping = {}
h_index = len(base_h3_grid)
for h in base_h3_grid:
if h not in h3_mapping:
h_index -= 1
h3_mapping[h] = h_index + num_latlons
# Now have the h3 grid mapping, the bipartite graph of edges connecting lat/lon to h3 nodes
# Should have vertical and horizontal difference
h3_distances = []
for idx, h3_point in enumerate(h3_grid):
lat_lon = lat_lons[idx]
distance = h3.point_dist(lat_lon, h3.h3_to_geo(h3_point), unit="rads")
h3_distances.append([np.sin(distance), np.cos(distance)])
h3_distances = torch.tensor(h3_distances, dtype=torch.float)
return base_h3_map, h3_mapping, h3_distances


def generate_latent_h3_graph(base_h3_map: dict, base_h3_grid: dict) -> torch.Tensor:
"""Generate latent h3 graph.

Args:
base_h3_map: Mapping from h3 index to index in latent graph
h3_mapping: Mapping from lat/lon to h3 index
h3_distances: Distances between lat/lon and h3 index

Returns:
Latent h3 graph
"""
# Get connectivity of the graph
edge_sources = []
edge_targets = []
edge_attrs = []
for h3_index in base_h3_grid:
h_points = h3.k_ring(h3_index, 1)
for h in h_points: # Already includes itself
distance = h3.point_dist(h3.h3_to_geo(h3_index), h3.h3_to_geo(h), unit="rads")
edge_attrs.append([np.sin(distance), np.cos(distance)])
edge_sources.append(base_h3_map[h3_index])
edge_targets.append(base_h3_map[h])
edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)
edge_attrs = torch.tensor(edge_attrs, dtype=torch.float)
# Use heterogeneous graph as input and output dims are not same for the encoder
# Because uniform grid now, don't need edge attributes as they are all the same
return Data(edge_index=edge_index, edge_attr=edge_attrs)
Loading
Loading