Graph Neural Network tutorial #2
Hi and thanks for the interest! The Quickstart section of the readme already demonstrates a GNN, and you can find more in the examples directory. I strongly recommend using simpler architectures (e.g. APPNP) when processing a lot of edges. The corresponding tutorial will be uploaded by Monday with the next commit; I will update this issue at that time. Feel free to point out errors, missing functionality, or other things that are hard to use or understand when working with the library.

P.S. Currently, automatic dataset downloading is disabled (it will be back on in a week or so, once we migrate to the new approach), but you can train architectures with your own data just fine.

Edit: Apparently the quickstart is a little outdated in terms of dataset management. (But the architecture itself should work just fine as-is.)
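For a first sense of what an architecture definition looks like, here is a minimal node classification sketch that assumes the same LayeredBuilder syntax used in the graph classification snippets later in this thread; it is not a copy of the actual Quickstart code, and numFeatures, numClasses, and the hidden size of 16 are placeholders:

ModelBuilder builder = new LayeredBuilder()   // the constructor defines "h0" to hold node features
    .var("A")                                 // adjacency matrix as a second architecture input
    .config("features", numFeatures)          // placeholder: number of input features per node
    .config("hidden", 16)                     // placeholder: latent dimension
    .config("classes", numClasses)            // placeholder: number of node classes
    .layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden))+vector(hidden))")
    .layer("h{l+1}=softmax(A@(h{l}@matrix(hidden, classes))+vector(classes), row)")
    .out("h{l}");                             // per-node class predictions
Model model = builder.getModel().init(new XavierNormal());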
Added a first take on a GNN tutorial. Can you give some feedback, @gavalian (or whoever stumbles upon this issue; don't hesitate to re-open it in the future), on things that are difficult to understand?
Hi @maniospas,
Hi again,

Issue #1 already mentions the prospect of providing graph classification, but I have been concentrating on making sure that everything else works correctly and have not yet gotten around to creating and testing code for this capability. If you don't mind manually calling the backpropagation, you can perform graph classification using the general ModelBuilder:

ModelBuilder builder = new LayeredBuilder()
.var("A") // set the graph as an architecture input; this is the second input, as the layered builder already defines a first one "h0" (not to be confused with any "h{0}") from its constructor to hold node features
.layer(...) // define your architecture per the tutorial
....
.layer("h{l+1}=softmax(mean(h{l}, row))") // mean pooling across nodes; better graph pooling will probably be provided in the future
.out("h{l}"); // set the model to output the outcome of the top layer
Model model = builder.getModel().init(new XavierNormal());
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01));
Loss loss = new CategoricalCrossEntropy();
for(int epoch=0; epoch<300; epoch++) {
for(int graphId=0; graphId<graphLabels.size(); graphId++) {
Matrix adjacency = graphMatrices.get(graphId); // retrieve from a list
Matrix nodeFeatures = graphNodeFeatures.get(graphId); // retrieve from a list
Tensor graphLabel = graphLabels.get(graphId); // one-hot encoding of graph labels, e.g. created with Tensor label = new DenseTensor(numClasses).put(classId, 1.0);
model.train(loss, optimizer,
Arrays.asList(nodeFeatures, adjacency),
Arrays.asList(graphLabel));
}
optimizer.updateAll();
}

Edit: Code improvements; apparently softmax is not implemented for simple tensors (but works well for matrices). Will push a preliminary working version tomorrow.
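Once trained this way, the model can classify an individual graph with a predict call; a minimal sketch, where nodeFeatures and adjacency are prepared as in the loop above and the long return type of argmax is an assumption based on how it is used for evaluation elsewhere in this thread:

long predictedClass = model.predict(Arrays.asList(nodeFeatures, adjacency))
    .get(0)     // the first (and only) model output holds the class probabilities
    .argmax();  // index of the largest probability, i.e. the predicted class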
I'm working on a generic example; if it works, I can forward you the code in case you want to post it as an example. Meanwhile, the LayeredBuilder() does not have a method layer(...).
Thanks a lot, I would really appreciate a coherent example if you make one work. I pushed changes to make the code work. The issue was that '.var(...)' returned a base ModelBuilder rather than the LayeredBuilder, so the chained '.layer(...)' calls were not available. I also added a simple graph classification example that is a more complete version of the above take, as well as a graph classification tutorial, though as a first take it could be missing things or be unrefined.
I tried the example, and everything works fine (after I updated the dependency). However, the network fails to learn. I wrote a small data provider class that generates an object trajectory along with a false trajectory, with labels true and false. When I run the graphs through the training, it converges to an accuracy of 50% (in other words, it does not learn); however, when I run the same data set through an MLP network, it learns the dataset and performs on the test sample with an accuracy of 99.999%. https://drive.google.com/file/d/10lYVnclZzdYCflScdwbqJSQSngwXu7oD/view?usp=sharing

This is a simple one-dimensional trajectory prediction (I'm trying to get the simple case working); my goal is to eventually get 3-D trajectory classification working. The reason for using graphs instead of an MLP is that not all nodes will be present along the trajectory, which leads to different graph sizes.
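For reference, a hypothetical sketch of how one such variable-length trajectory could be packed into a graph: a chain adjacency matrix linking consecutive points, one feature per node, and a one-hot label. The SparseMatrix and DenseMatrix class names and their put signatures are assumptions rather than code from this thread, so substitute whatever matrix implementations your JGNN version provides:

double[] trajectoryValues = {0.1, 0.2, 0.4, 0.7, 1.1}; // hypothetical 1-D measurements along one trajectory
boolean isTrueTrajectory = true;                       // hypothetical label: real vs. generated false trajectory
int n = trajectoryValues.length;
Matrix adjacency = new SparseMatrix(n, n);             // assumed sparse matrix type with a put(row, col, value) setter
for (int i = 0; i + 1 < n; i++) {                      // chain graph: consecutive trajectory points are linked
    adjacency.put(i, i + 1, 1.0);
    adjacency.put(i + 1, i, 1.0);
}
Matrix nodeFeatures = new DenseMatrix(n, 1);           // assumed dense matrix type; one feature per node
for (int i = 0; i < n; i++)
    nodeFeatures.put(i, 0, trajectoryValues[i]);
Tensor label = new DenseTensor(2).put(isTrueTrajectory ? 1 : 0, 1.0); // one-hot true/false label, as in the snippet above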
Thanks a lot for the example. I am looking into it and will update this issue.
It turns out that mean pooling is too naive for this application. For the time being, I finished adding support for a sort pooling operation from the literature.
Consider this reply a tentative first working take on this type of architecture. In the next commits, I will also add your example to the code base @gavalian, as it provides a very interesting use case that is easy to experiment with.

Architecture

Sorting can be integrated as in the following example (don't forget to upgrade to the latest version of the library from jitpack first). I tested the snippet locally and it yields approx. 95% accuracy on the above setting after 500 epochs of training; this is not as impressive as near-perfect MLP performance, but could be acceptable in practice. I would be interested in hearing further insights.

By the way, don't stop training if the architecture keeps producing random or worse-than-random test accuracy before epoch 50. This happens because sorting depends on learned values that keep changing, and with them the model's understanding of the graph's structure. As learning converges, so does the internally understood ordering of nodes, and there is a point after which accuracy skyrockets and remains high.

To explain the concept of sort pooling intuitively, due to the lack of a respective tutorial for the time being: the reduced hyperparameter keeps only that many topologically important nodes, ordered in terms of their importance (where importance is measured by latent feature values). The idea is that these nodes also include information propagated from the other nodes.

long reduced = 5; // input graphs need to have at least that many nodes; lower values decrease accuracy
long hidden = 8; // since this library does not use GPU parallelization, many latent dims reduce speed
ModelBuilder builder = new LayeredBuilder()
.var("A")
.config("features", 1)
.config("classes", 2)
.config("reduced", reduced)
.config("hidden", hidden)
.layer("h{l+1}=relu(A@(h{l}@matrix(features, hidden))+vector(hidden))") // don't forget to add bias vectors to dense transformations
.layer("h{l+1}=relu(A@(h{l}@matrix(hidden, hidden))+vector(hidden))")
.concat(2) // concatenates the outputs of the last 2 layers
.config("hiddenReduced", hidden*2*reduced) // 2* due to concatenation
.operation("z{l}=sort(h{l}, reduced)") // currently, the parser fails to understand full expressions within the next step's gather, so we need to create this intermediate variable
.layer("h{l+1}=reshape(h{l}[z{l}], 1, hiddenReduced)") // gather the reduced nodes' rows and flatten them into a single row vector
.layer("h{l+1}=h{l}@matrix(hiddenReduced, classes)")
.layer("h{l+1}=softmax(h{l}, row)")
.out("h{l}");

For training, the labels should also be cast into row vectors to be compliant with the architecture's outputs (see the next code snippet).

Parallelized Training

I was too focused on the architecture previously, but for a dataset this large you can take advantage of multi-core processors to calculate derivatives in parallel during training. JGNN is thread-safe and provides its own simplified thread pool to help you do this:

for(int epoch=0; epoch<500; epoch++) {
// gradient update
for(int graphId=0; graphId<dtrain.adjucency.size(); graphId++) {
int graphIdentifier = graphId;
ThreadPool.getInstance().submit(new Runnable() {
@Override
public void run() {
Matrix adjacency = dtrain.adjucency.get(graphIdentifier);
Matrix features = dtrain.features.get(graphIdentifier);
Tensor graphLabel = dtrain.labels.get(graphIdentifier).asRow(); // Don't forget to cast to the same format as predictions.
model.train(loss, optimizer,
Arrays.asList(features, adjacency),
Arrays.asList(graphLabel));
}
});
}
ThreadPool.getInstance().waitForConclusion(); // waits for all gradients to finish calculating
optimizer.updateAll();
double acc = 0.0;
for(int graphId=0; graphId<dtest.adjucency.size(); graphId++) {
Matrix adjacency = dtest.adjucency.get(graphId);
Matrix features = dtest.features.get(graphId);
Tensor graphLabel = dtest.labels.get(graphId);
if(model.predict(Arrays.asList(features, adjacency)).get(0).argmax()==graphLabel.argmax())
acc += 1;
}
System.out.println("iter = " + epoch + " " + acc/dtest.adjucency.size()); // report test accuracy once per epoch
}

Notes

I am not yet closing this issue, because I also need to update the related tutorial. For more discussion, requests on pooling for graph classification, or requests for more tutorials, please open separate issues.
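For completeness, the loop above assumes that the model, loss, and optimizer have already been created, in the same way as in the first snippet of this issue; repeated here for reference:

Model model = builder.getModel().init(new XavierNormal());     // builder is the sort pooling architecture defined above
BatchOptimizer optimizer = new BatchOptimizer(new Adam(0.01)); // accumulates per-graph gradients until updateAll()
Loss loss = new CategoricalCrossEntropy();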
This is a very interesting library and I want to try this for my project. I wanted to know if it's possible to have a Graph Neural Network example in the tutorials?