
skip connection option added #165

Closed
wants to merge 15 commits into from

allaffa
Collaborator

@allaffa allaffa commented Feb 22, 2023

As the name suggests, skip connections in deep architectures bypass some of the neural network layers and feed the output of one layer as input to later layers. They are a standard module and provide an alternative path for the gradient during backpropagation.
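For the additive (ResNet-style) case, the "alternative path for the gradient" can be written out explicitly. Here $F$ stands for the bypassed layers and $\mathcal{L}$ for the training loss; this notation is introduced for illustration and is not taken from the HydraGNN code:

$$
\mathbf{x}_{l+1} = \mathbf{x}_l + F(\mathbf{x}_l)
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial \mathbf{x}_l}
= \frac{\partial \mathcal{L}}{\partial \mathbf{x}_{l+1}}
\left( I + \frac{\partial F}{\partial \mathbf{x}_l} \right)
$$

The identity term $I$ carries the gradient from layer $l+1$ back to layer $l$ unchanged, even when $\partial F / \partial \mathbf{x}_l$ is small.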

Skip connections were originally introduced to address different difficulties in different architectures, and they predate residual networks. In residual networks (ResNets), skip connections were used to address the degradation problem (e.g., vanishing gradients); in dense networks (DenseNets), they ensure feature reuse.

Since residual blocks with skip connections have been shown to benefit CNN training, I thought it would be good to provide them as an option to stabilize GNN training as well (after all, the GNN architecture is a generalization of the CNN).

@allaffa allaffa added the enhancement New feature or request label Feb 22, 2023
@allaffa allaffa self-assigned this Feb 22, 2023
@allaffa allaffa linked an issue Feb 22, 2023 that may be closed by this pull request
@allaffa allaffa marked this pull request as draft February 22, 2023 04:03
@allaffa
Collaborator Author

allaffa commented Feb 22, 2023

@pzhanggit @jychoi-hpc
I opened this PR as a WIP.
For now, the skip connection is enabled for all layer types except GATv2Conv.
If I set self.heads=1, everything works well. However, for any value of self.heads larger than 1, the code crashes: the multiple heads of the GATv2Conv layer change the dimension of the variable x across message passing layers, so the input of a layer can no longer be added directly to its output. I would like to look further into the GATv2Conv layers and see whether there is a practical way to apply the skip connection to this layer as well.
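The dimension problem can be reproduced without any GNN library. In the toy sketch below (all function names are hypothetical), concatenating the per-head outputs makes x wider, so it can no longer be added to the layer input. Two common workarounds are sketched: averaging the heads (PyTorch Geometric's `GATv2Conv` exposes this through its `concat=False` option) or applying a linear projection on the skip path. Whether either fits HydraGNN's GATv2Conv stack is an assumption, untested here:

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def attention_heads(x: Vector, heads: int) -> List[Vector]:
    """Toy stand-in for multi-head attention: head h simply scales x by (h + 1).
    A real attention head would compute attention-weighted neighbour messages."""
    return [[(h + 1) * a for a in x] for h in range(heads)]


def concat_heads(outs: List[Vector]) -> Vector:
    """Concatenate the head outputs: the feature size grows to heads * len(x)."""
    return [a for out in outs for a in out]


def mean_heads(outs: List[Vector]) -> Vector:
    """Average the head outputs: the feature size stays len(x)."""
    return [sum(col) / len(outs) for col in zip(*outs)]


def add_skip(x: Vector, h: Vector) -> Vector:
    """Residual sum, only defined when input and output have the same size."""
    if len(x) != len(h):
        raise ValueError(f"cannot add skip: {len(x)} input vs {len(h)} output features")
    return [a + b for a, b in zip(x, h)]


def project(w: Matrix, x: Vector) -> Vector:
    """Linear map on the skip path, used to match the concatenated width."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


x = [1.0, 2.0]
outs = attention_heads(x, heads=2)  # [[1.0, 2.0], [2.0, 4.0]]
# add_skip(x, concat_heads(outs)) would raise: 2 input vs 4 output features.
print(add_skip(x, mean_heads(outs)))  # option 1, average heads: [2.5, 5.0]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # hypothetical 4x2 projection
print(add_skip(project(w, x), concat_heads(outs)))  # option 2, project: [2.0, 4.0, 3.0, 6.0]
```

Averaging keeps the layer width fixed at no extra cost, while the projection keeps the richer concatenated features at the price of an extra learnable weight on the skip path.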

@allaffa allaffa marked this pull request as ready for review April 14, 2023 15:19
@allaffa
Collaborator Author

allaffa commented Apr 14, 2023

@pzhanggit @jychoi-hpc
I marked this PR as ready for review to kick off the review process :)

Member

@jychoi-hpc jychoi-hpc left a comment

It looks good to me. As long as we resolve the conflicts, I think it is OK to merge.

@allaffa
Collaborator Author

allaffa commented Apr 14, 2023

> It looks good to me. As long as we solve the conflict, I think it is ok to merge.

@jychoi-hpc I manually resolved the conflicts and rebased to maintain a linear history on the main branch. Would you mind double-checking that I didn't remove any of your changes? Thanks.

Collaborator

@pzhanggit pzhanggit left a comment

Looks good to me too. I only have a few minor comments. I will approve once they are addressed and the tests pass.

@@ -100,6 +140,15 @@ def __init__(
        if self.initial_bias is not None:
            self._set_bias()

        self.layers = ModuleList()
        for i, (conv, batch_norm) in enumerate(zip(self.convs, self.batch_norms)):
Collaborator

Remove i and enumerate since not used anymore
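A minimal sketch of the suggested loop header. The loop body is not shown in the review snippet, so the `append` below is a placeholder, and the string lists are hypothetical stand-ins for the torch modules held in `self.convs` and `self.batch_norms`:

```python
# Hypothetical stand-ins for self.convs and self.batch_norms (torch modules in HydraGNN).
convs = ["conv_0", "conv_1", "conv_2"]
batch_norms = ["bn_0", "bn_1", "bn_2"]

layers = []
# Before: for i, (conv, batch_norm) in enumerate(zip(convs, batch_norms)):
# After: the unused index i, and therefore enumerate(), are dropped.
for conv, batch_norm in zip(convs, batch_norms):
    layers.append((conv, batch_norm))  # placeholder for the real loop body
```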

Comment on lines 296 to 297
count_conv_layers = 0

Collaborator

Not needed?

hydragnn/models/Base.py (resolved)
hydragnn/models/create.py (resolved)
@allaffa allaffa requested review from jychoi-hpc and pzhanggit April 14, 2023 19:55
Collaborator

@pzhanggit pzhanggit left a comment

The forward function needs to be updated: the skip connection is defined in self.conv_shared, but forward never calls it; it still calls the old self.convs.
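A minimal sketch of the fix the reviewer asks for, in plain Python so it runs standalone. `ToyModel`, `_residual` and `double` are hypothetical; only the attribute names `self.convs` and `self.conv_shared` come from the review comment, and in HydraGNN those attributes hold PyTorch modules rather than lists of functions:

```python
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


class ToyModel:
    """Hypothetical stand-in for the model class: __init__ builds both the old
    per-layer convs and the wrapped conv_shared stack; forward must use the latter."""

    def __init__(self, convs: List[Layer], use_skip: bool):
        self.convs = convs
        self.conv_shared = [self._residual(c) if use_skip else c for c in convs]

    @staticmethod
    def _residual(conv: Layer) -> Layer:
        """Wrap a layer so that its input is added to its output."""
        def layer(x: Vector) -> Vector:
            return [a + b for a, b in zip(x, conv(x))]
        return layer

    def forward(self, x: Vector) -> Vector:
        # Iterate over conv_shared (not the old self.convs); otherwise the
        # skip connections are built in __init__ but never executed.
        for layer in self.conv_shared:
            x = layer(x)
        return x


def double(x: Vector) -> Vector:
    return [2.0 * a for a in x]


print(ToyModel([double, double], use_skip=True).forward([1.0, 2.0]))   # [9.0, 18.0]
print(ToyModel([double, double], use_skip=False).forward([1.0, 2.0]))  # [4.0, 8.0]
```

Because `forward` reads only `conv_shared`, switching the skip option on or off requires no change to the forward pass itself.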

hydragnn/models/Base.py (resolved)
@allaffa allaffa marked this pull request as draft April 15, 2023 14:48
@allaffa allaffa marked this pull request as ready for review April 18, 2023 14:05
@allaffa allaffa requested a review from pzhanggit April 18, 2023 14:05
@allaffa
Collaborator Author

allaffa commented Apr 18, 2023

@pzhanggit @jychoi-hpc
Thanks. I marked the PR as ready for review again for an additional round :-)

@jychoi-hpc
Member

I found an error in my modification to ConvSequential. I will get back with a fix soon.

@jychoi-hpc
Member

@allaffa I made a fix for what I found: allaffa#4. If you merge that PR into your branch, the fix will show up here. The unit tests mostly pass for me, but the error check fails:

E           AssertionError: Head RMSE checking failed for 0
E           assert tensor(0.2304) < 0.2

We may need to adjust the error bounds.
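If the bounds are loosened, keeping them per head makes the change explicit and reviewable. The sketch below is a hypothetical helper that only mirrors the assertion message in the log above; it is not HydraGNN's test code, and the sample values are made up for illustration:

```python
import math
from typing import List, Sequence


def rmse(pred: Sequence[float], true: Sequence[float]) -> float:
    """Root-mean-square error between two equally long sequences."""
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true))


def check_heads(preds: List[List[float]], trues: List[List[float]], bounds: List[float]) -> None:
    """Assert that each output head stays under its own RMSE bound."""
    for ihead, (pred, true, bound) in enumerate(zip(preds, trues, bounds)):
        error = rmse(pred, true)
        assert error < bound, f"Head RMSE checking failed for {ihead}: {error:.4f} >= {bound}"


# Made-up predictions for one head, with an RMSE of about 0.212.
preds = [[0.0, 0.0, 0.0, 0.0]]
trues = [[0.3, 0.3, 0.0, 0.0]]
check_heads(preds, trues, bounds=[0.25])  # passes with the looser bound
# check_heads(preds, trues, bounds=[0.2]) would raise AssertionError
```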

@allaffa allaffa force-pushed the resnet_gnn branch 2 times, most recently from 96e68a3 to 4415a11 on April 26, 2023 19:28
Labels
enhancement New feature or request
Development

Successfully merging this pull request may close these issues.

Skip connection for ResNet type of GNN
3 participants