Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flexible options to archetype graphs #19

Merged
merged 24 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
cf91b2a
Handle non-square areas for single-level flat graphs
joeloskarsson Aug 14, 2024
d593a4f
Fix tests
joeloskarsson Aug 14, 2024
c402c5e
Introduce radius graph connection with distance relative to longest edge
joeloskarsson Aug 14, 2024
7e3f1b4
Handle relative connection radius accurately in multiscale graphs
joeloskarsson Sep 3, 2024
844fbf0
Fix relative refinement factor for g2m also for hierarchical graphs
joeloskarsson Sep 3, 2024
c50f00a
Change multiscale archetype to enforce refinement factor 3
joeloskarsson Sep 3, 2024
43a0815
Start work on creating separate refinement factors between grid-mesh …
joeloskarsson Sep 3, 2024
2e80a5c
Clean up separate refinement factors to also work with multiscale graph
joeloskarsson Sep 4, 2024
0a0511e
Only connect grid to bottom level of hierarchical mesh
joeloskarsson Sep 4, 2024
48b4fe0
Fix m2g specification for archetypes
joeloskarsson Sep 4, 2024
661ad4e
Correct doscstrings to match updated archetypes
joeloskarsson Sep 4, 2024
ec43d63
Update docs to match archetype changes
joeloskarsson Sep 4, 2024
9a9d78b
Fix tests
joeloskarsson Sep 4, 2024
f86e186
Add test for new rel_max_dist within_radius parameter
joeloskarsson Sep 4, 2024
4039ccc
Fix some comments and checks
joeloskarsson Sep 4, 2024
5705945
Run pre-commit on docs
joeloskarsson Sep 4, 2024
e2e5af0
Fix typos
joeloskarsson Sep 5, 2024
ba2cc14
Add explanation about why default distance is 0.51d
joeloskarsson Sep 5, 2024
7c999aa
Clarify docstring for split_on_edge_attribute_existance
joeloskarsson Sep 5, 2024
115eeea
Clarify dimension names in mesh creation
joeloskarsson Sep 10, 2024
c484161
Clarify grid coordinates also in docstring for flat mesh graph
joeloskarsson Sep 10, 2024
2ea2117
Update changelog
joeloskarsson Sep 12, 2024
1395a4b
Clear outputs in documentation notebooks
joeloskarsson Sep 12, 2024
bb7f37b
Merge branch 'main' into archetype_changes
joeloskarsson Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,28 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Vim ###
# Swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
*~

# Auto-generated tag files
tags

# Persistent undo
[._]*.un~

# Coc configuration directory
.vim
40 changes: 32 additions & 8 deletions docs/background.ipynb

Large diffs are not rendered by default.

536 changes: 454 additions & 82 deletions docs/creating_the_graph.ipynb

Large diffs are not rendered by default.

106 changes: 69 additions & 37 deletions src/weather_model_graphs/create/archetype.py
leifdenby marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,23 +1,27 @@
from .base import create_all_graph_components


def create_keisler_graph(xy_grid):
def create_keisler_graph(xy_grid, grid_refinement_factor=3):
"""
Create a graph following Keisler (2022, https://arxiv.org/abs/2202.07575) architecture.
Create a flat LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
This graph setup is inspired by the global graph used by Keisler (2022, https://arxiv.org/abs/2202.07575).

This graph is a flat multiscale graph with nearest neighbour connectivity
(8 neighbours) within the mesh. The grid to mesh connectivity connects each mesh node to
the four nearest grid points. The mesh to grid connectivity connects each grid point to the
nearest mesh node.
This graph is a flat single scale graph with nearest neighbour connectivity
(8 neighbours) within the mesh.

TODO: Verify that Keisler does in fact use these g2m and m2g connectivities.
The grid to mesh connectivity connects each mesh node to grid nodes withing
distance 0.51d, where d is the length of diagonal edges between neighbouring
mesh nodes. The choice of 0.51 makes sure that all grid node positions will
be connected to at least one mesh node (see
https://www.desmos.com/calculator/sqqz0ka4ho for a visualization).
The mesh to grid connectivity connects each grid point to the 4 nearest mesh nodes.

Parameters
----------
xy_grid: np.ndarray
2D array of grid point positions.
merge_components: bool
Whether to merge the components of the graph.
grid_refinement_factor: float
Refinement factor between grid points and mesh

Returns
-------
Expand All @@ -27,34 +31,43 @@ def create_keisler_graph(xy_grid):
return create_all_graph_components(
xy=xy_grid,
m2m_connectivity="flat",
m2m_connectivity_kwargs={},
m2g_connectivity="nearest_neighbour",
g2m_connectivity="nearest_neighbours",
m2m_connectivity_kwargs=dict(grid_refinement_factor=grid_refinement_factor),
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
g2m_connectivity_kwargs=dict(
rel_max_dist=0.51,
),
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
)


def create_graphcast_graph(xy_grid, refinement_factor=3, max_num_levels=None):
def create_graphcast_graph(
xy_grid, grid_refinement_factor=3, level_refinement_factor=3, max_num_levels=None
):
"""
Create a graph following the Lam et al (2023, https://arxiv.org/abs/2212.12794) GraphCast architecture.
Create a multiscale LAM graph from Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
This graph setup is inspired by the global GraphCast graph used by Lam et al (2023, https://arxiv.org/abs/2212.12794)

This graph is a flat multiscale graph with nearest neighbour connectivity (4 neighbours) with both nearest
neighbour and longer range connections in the mesh, using the `refinement_factor` and `max_num_levels` parameters
to constrain the range-length of the connections. The grid to mesh connectivity connects each mesh node to
to its nearest 4 grid points. The mesh to grid connectivity connects each grid point to the nearest mesh node.
This graph is a flat multiscale graph with neighbour connectivity and longer multi-scale edges.

TODO: Verify that GraphCast does in fact use these g2m and m2g connectivities.
The grid to mesh connectivity connects each mesh node to grid nodes withing
distance 0.51d, where d is the length of diagonal edges between neighbouring
mesh nodes. The choice of 0.51 makes sure that all grid node positions will
be connected to at least one mesh node (see
https://www.desmos.com/calculator/sqqz0ka4ho for a visualization).
The mesh to grid connectivity connects each grid point to the 4 nearest mesh nodes.

Parameters
----------
xy_grid: np.ndarray
2D array of grid point positions.
refinement_factor: int
Refinement factor for longer-range connections in the mesh graph, the
reduction factor in the number of mesh points between levels (in both
x and y directions).
grid_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy
level_refinement_factor: int
Refinement factor between grid points and bottom level of mesh hierarchy
NOTE: Must be an odd integer >1 to create proper multiscale graph
max_num_levels: int
The number of levels of longer-range connections in the mesh graph.

Expand All @@ -67,39 +80,51 @@ def create_graphcast_graph(xy_grid, refinement_factor=3, max_num_levels=None):
xy=xy_grid,
m2m_connectivity="flat_multiscale",
m2m_connectivity_kwargs=dict(
refinement_factor=refinement_factor, max_num_levels=max_num_levels
grid_refinement_factor=grid_refinement_factor,
level_refinement_factor=level_refinement_factor,
max_num_levels=max_num_levels,
),
m2g_connectivity="nearest_neighbour",
g2m_connectivity="nearest_neighbours",
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
Comment on lines +87 to +88
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand that these are "historical" graphs, but what is the rational behind having a different connectivity in g2m and m2g? (Mostly asking for educatory purposes, I think the code is fine).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I have ever read a thorough explanation, but mostly used this for historical reasons as well. One motivation could be based on what we think g2m and m2g does. g2m is supposed to aggregate grid information up to the mesh. It could then be useful to include a large area around each mesh node as the "information aggregation window", and overlaps in this are not a problem. For m2g, its purpose is to extract the information from the mesh to determine the final prediction in each grid node. At this point we expect this information to be localised to the closest mesh nodes, so in a sense m2g only performs a fancy interpolation between the closest mesh nodes. If we connect to the closest mesh nodes we don't expect mesh nodes further away to contribute with more infromation.

g2m_connectivity_kwargs=dict(
rel_max_dist=0.51,
),
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
)


def create_oscarsson_hierarchical_graph(xy_grid):
def create_oskarsson_hierarchical_graph(
xy_grid, grid_refinement_factor=3, level_refinement_factor=3, max_num_levels=None
):
"""
Create a graph following Oscarsson et al (2023, https://arxiv.org/abs/2309.17370)
Create a LAM graph following Oskarsson et al (2023, https://arxiv.org/abs/2309.17370)
hierarchical architecture.

The mesh graph in this architecture is hierarchical in that each refinement of
longer-range edges are split into different levels. In addition to these same-level
connections the mesh graph contains nearest neighbour connections between
levels (up and down). To distinguish between these these three types of
edge connections each edge has a `direction` attribute (with value "up",
"down", or "same"). In addition the `level` attribute indicates which two levels
"down", or "same"). In addition, the `levels` attribute indicates which two levels
are connected for cross-level edges (e.g. "1>2" for edges between level 1 and 2).

The grid to mesh connectivity connects each mesh node to the four nearest
grid points, and the mesh to grid connectivity connects each grid point to
the nearest mesh node.

TODO: Is this the right connectivity for the g2m and m2g components?
The grid to mesh connectivity connects each mesh node to grid nodes withing
distance 0.51d, where d is the length of diagonal edges between neighbouring
mesh nodes. The choice of 0.51 makes sure that all grid node positions will
be connected to at least one mesh node (see
https://www.desmos.com/calculator/sqqz0ka4ho for a visualization).
The mesh to grid connectivity connects each grid point to the 4 nearest mesh nodes.

Parameters
----------
xy_grid: np.ndarray
2D array of grid point positions.
grid_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy
level_refinement_factor: float
Refinement factor between grid points and bottom level of mesh hierarchy

Returns
-------
Expand All @@ -109,10 +134,17 @@ def create_oscarsson_hierarchical_graph(xy_grid):
return create_all_graph_components(
xy=xy_grid,
m2m_connectivity="hierarchical",
m2m_connectivity_kwargs=dict(refinement_factor=2, max_num_levels=3),
m2g_connectivity="nearest_neighbour",
g2m_connectivity="nearest_neighbours",
m2m_connectivity_kwargs=dict(
grid_refinement_factor=grid_refinement_factor,
level_refinement_factor=level_refinement_factor,
max_num_levels=max_num_levels,
),
g2m_connectivity="within_radius",
m2g_connectivity="nearest_neighbours",
g2m_connectivity_kwargs=dict(
rel_max_dist=0.51,
),
m2g_connectivity_kwargs=dict(
max_num_neighbours=4,
),
)
Loading
Loading