
GenCast's Processor #119

Merged (16 commits) on Jul 11, 2024

Conversation

@gbruno16 (Contributor) commented on Jul 2, 2024

Pull Request

Description

From the paper:

The Processor is a graph transformer model operating on a spherical mesh that computes neighbourhood-based self-attention. Unlike the multimesh used in GraphCast, the mesh in GenCast is a 6-times refined icosahedral mesh as defined in Lam et al. (2023), with 41,162 nodes and 246,960 edges. The Processor consists of 16 consecutive standard transformer blocks (Nguyen and Salazar, 2019; Vaswani et al., 2017), with a feature dimension equal to 512. The 4-head self-attention mechanism in each block is such that each node in the mesh attends to itself and to all other nodes within its 32-hop neighbourhood on the mesh.

Transformers

In this PR, there are two different versions of the transformer blocks:

  • not-sparse (default): uses PyTorch Geometric (PyG) as the backend. It implements the attention described in the paper "Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification" and can also handle edge features.
  • sparse (experimental): uses DGL. It is a classical transformer that performs multi-head attention exploiting the sparsity of the attention mask, and it does not include edge features in the computation. The DGL version should be faster and more memory-efficient when the mesh has many edges.
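To make the shared idea behind both backends concrete, here is a minimal NumPy sketch of neighbourhood-masked multi-head self-attention. This is illustrative only, not the PR's PyG or DGL code: the weight matrices are random stand-ins for learned parameters, and the mask plays the role of the k-hop neighbourhood.

```python
import numpy as np

def masked_self_attention(x, mask, num_heads=4):
    """Toy multi-head self-attention where node i may only attend to
    nodes j with mask[i, j] = True (e.g. its k-hop neighbourhood).

    x:    (num_nodes, dim) node features
    mask: (num_nodes, num_nodes) boolean attention mask (diagonal True)
    """
    num_nodes, dim = x.shape
    head_dim = dim // num_heads
    rng = np.random.default_rng(0)
    # Random projections standing in for learned Q/K/V weights.
    wq, wk, wv = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3))
    q = (x @ wq).reshape(num_nodes, num_heads, head_dim)
    k = (x @ wk).reshape(num_nodes, num_heads, head_dim)
    v = (x @ wv).reshape(num_nodes, num_heads, head_dim)
    # Per-head attention scores: shape (heads, nodes, nodes).
    scores = np.einsum("ihd,jhd->hij", q, k) / np.sqrt(head_dim)
    # Forbid attention outside the neighbourhood before the softmax.
    scores = np.where(mask[None, :, :], scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    out = np.einsum("hij,jhd->ihd", weights, v).reshape(num_nodes, dim)
    return out, weights
```

The dense version above materializes the full (nodes, nodes) score matrix; the point of the sparse DGL variant is to compute scores only for the nonzero mask entries, which matters at GenCast's scale (41,162 nodes, 32-hop neighbourhoods).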

Conditional Layer Normalization

Every LayerNorm layer is replaced by a custom module: an element-wise affine transformation is applied to the output of the LayerNorm, with the scale and shift parameters computed by linear layers from Fourier embeddings of the noise levels.
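The mechanism can be sketched as follows in NumPy. This is a sketch under assumptions, not the module in the PR: `fourier_embedding`'s exact frequencies and the weight shapes are hypothetical, and the linear weights are passed in as plain arrays rather than learned parameters.

```python
import numpy as np

def fourier_embedding(noise_level, dim=16, max_freq=16.0):
    """Hypothetical Fourier features of a scalar noise level."""
    freqs = np.exp(np.linspace(0.0, np.log(max_freq), dim // 2))
    angles = noise_level * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])

def conditional_layer_norm(x, noise_level, w_scale, w_shift, eps=1e-5):
    """LayerNorm followed by an element-wise affine transform whose
    scale and shift are linear functions of the noise-level embedding.

    x: (..., dim); w_scale, w_shift: (embed_dim, dim) stand-in weights.
    """
    # Plain layer normalization over the feature dimension.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mu) / np.sqrt(var + eps)
    # Conditioning: scale/shift depend on the noise level.
    emb = fourier_embedding(noise_level, dim=w_scale.shape[0])
    scale = 1.0 + emb @ w_scale  # offset so zero weights give the identity
    shift = emb @ w_shift
    return normed * scale + shift
```

With zero conditioning weights this reduces to a plain LayerNorm, which is a common initialization choice for conditioned normalization layers.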

K-hop Neighbours

The k-hop mesh graph is now computed using sparse multiplications of the adjacency matrix instead of relying on the PyG implementation.
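The idea can be sketched with SciPy sparse matrices: the nonzero pattern of the k-th power of (adjacency + identity) marks exactly the pairs within k hops. This is an illustrative sketch, not the PR's implementation.

```python
import numpy as np
import scipy.sparse as sp

def khop_adjacency(edge_index, num_nodes, k):
    """Pairs of nodes within k hops, via sparse powers of the
    adjacency matrix (self-loops included).

    edge_index: (2, num_edges) array of undirected edges.
    Returns a boolean sparse matrix of the k-hop graph.
    """
    row, col = edge_index
    adj = sp.coo_matrix(
        (np.ones(len(row)), (row, col)), shape=(num_nodes, num_nodes)
    ).tocsr()
    adj = ((adj + adj.T) > 0).astype(float)           # symmetrize
    hop = adj + sp.identity(num_nodes, format="csr")  # 1-hop + self
    reach = hop.copy()
    for _ in range(k - 1):
        # One sparse product extends reachability by one more hop;
        # only the nonzero pattern matters, so re-binarize each step.
        reach = ((reach @ hop) > 0).astype(float)
    return reach > 0
```

Because the mesh adjacency is very sparse, each product stays cheap compared with dense matrix powers, and the result directly yields the attention mask for the 32-hop neighbourhoods.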

@gbruno16 gbruno16 self-assigned this Jul 2, 2024
@jacobbieker (Member) left a comment:

Overall looks good to me! Just a minor nitpick. Nice implementation!

graph_weather/models/gencast/graph/graph_builder.py (outdated, resolved)
    )
else:
    if not has_dgl:
        raise ValueError("Please install DGL to use sparsity.")
A contributor replied:

nice

@gbruno16 gbruno16 merged commit d8758f3 into openclimatefix:main Jul 11, 2024
1 check failed

3 participants