
Add GRIT model #777

Open · pweigel wants to merge 16 commits into main
Conversation

@pweigel (Collaborator) commented Dec 17, 2024

GRIT: "Graph Inductive Biases in Transformers without Message Passing"

This PR adds a new model based on the GRIT transformer. It uses novel methods for encoding graph information in sparse multi-head attention blocks, together with a learned position encoding based on random-walk probabilities that enhances the model's expressivity.

PMLR: https://proceedings.mlr.press/v202/ma23c.html
Paper pre-print: https://arxiv.org/abs/2305.17589
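For intuition, here is a minimal sketch of the random-walk position encoding described above, assuming a small dense adjacency matrix for simplicity (my own illustration; the actual implementation in this PR is adapted from the original repository and differs in detail):

```python
import torch

def rrwp_encoding(adj: torch.Tensor, num_steps: int = 8) -> torch.Tensor:
    """Stack of k-step random-walk probabilities for a dense adjacency matrix.

    Entry [i, j, k] is the probability that a k-step random walk
    starting at node i ends at node j.
    """
    num_nodes = adj.size(0)
    deg = adj.sum(dim=1).clamp(min=1.0)      # guard against isolated nodes
    rw = adj / deg.unsqueeze(1)              # one-step transition matrix D^-1 A
    powers = [torch.eye(num_nodes)]          # k = 0: identity
    for _ in range(num_steps - 1):
        powers.append(powers[-1] @ rw)       # k-step probabilities
    return torch.stack(powers, dim=-1)       # (num_nodes, num_nodes, num_steps)

adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])
pe = rrwp_encoding(adj, num_steps=4)
node_pe = torch.diagonal(pe, dim1=0, dim2=1).T  # per-node encoding, shape (3, 4)
```

The diagonal gives a per-node encoding, while the full tensor provides pair-wise features that can be used to bias attention scores.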


Many layers/functions are adapted from the original repository: https://github.com/LiamMa/GRIT/tree/main. The original code uses GraphGym to set up most of its modules, so I refactored some things to fit into graphnet, and many of the arguments have been renamed to be more self-explanatory. In principle, other graph attention mechanisms could be used by replacing the GRIT MHA block.
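To make that swap-in point concrete, here is a simplified stand-in for such a block (a hypothetical sketch, not the GRIT MHA implementation in this PR): standard multi-head attention whose scores receive an additive, per-head bias learned from the pair-wise position encoding. The real GRIT block also updates the pair representations themselves, which this sketch omits.

```python
import torch
import torch.nn as nn

class PairBiasedAttention(nn.Module):
    """Multi-head attention with a learned bias from pair-wise encodings.

    A simplified illustration of the kind of block that could replace
    the GRIT MHA block; hypothetical, not the code added in this PR.
    """

    def __init__(self, dim: int, num_heads: int, pe_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.pe_bias = nn.Linear(pe_dim, num_heads)  # per-head additive bias
        self.out = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, pair_pe: torch.Tensor) -> torch.Tensor:
        # x: (num_nodes, dim); pair_pe: (num_nodes, num_nodes, pe_dim)
        n = x.size(0)
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(n, self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(n, self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(n, self.num_heads, self.head_dim).transpose(0, 1)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5  # (heads, n, n)
        scores = scores + self.pe_bias(pair_pe).permute(2, 0, 1)
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(0, 1).reshape(n, -1)
        return self.out(out)
```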

Since there are a lot of changes, I will quickly summarize the significant new additions and modifications to existing files:

This model has many hyperparameters, but the defaults should provide a good starting point. Note that the GPU memory required to train this model is quite high due to the use of global attention.
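As a rough illustration of the memory point (a back-of-envelope estimate, not a measurement from this PR): with global attention, each layer materializes a per-head N × N score matrix, so a single graph with N = 5,000 nodes and 8 heads already needs about 8 × 5,000² × 4 bytes ≈ 0.8 GB of fp32 attention scores per layer, before counting the activations kept for backprop.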
