Possible PyTorch implementation of WL kernel #153
base: master
Seems really good as mentioned previously in MM!
There are some more changes we'll need to discuss regarding the representation we pass in and how we'll get it to work with BoTorch, but we can discuss that later. For now, there's a bug which, from the looks of things, points either to a bug in GraKel or at least to a problem in how it's used, based on your comparison test with it.
Reproduction with description:
import networkx as nx
from torch_wl_kernel import GraphDataset, TorchWLKernel
# Create the same graphs as for the Grakel example
G1 = nx.Graph()
G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)])
G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])
# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])
# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)
K = wl_kernel(graphs)
print("Kernel matrix (pairwise similarities):")
print(K)
# Issue: GraphDataset.from_networkx() relabels the nodes of each graph independently,
# which is incorrect. We assume the edges all refer to the same nodes, i.e. the
# 0, 1, 2, 3, 4 in the graphs above are not independent of each other.
# ------------------------------------------------------------------
# Below, we re-ordered the edges, placing the first edge at the end.
# This is the same graph, yet the kernel returns something different.
#
#
# v-------------------------------------------v
# [(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]
# [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
#
# Take a look at the implementation of `nx.convert_node_labels_to_integers()`
# that is used in `GraphDataset.from_networkx()`. We likely need to create
# our own mapping and relabel the nodes as they do.
G1 = nx.Graph()
edges_g1 = [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
G1.add_edges_from(edges_g1)
print(list(G1.nodes()))
G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
print(list(G2.nodes()))
G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])
print(list(G3.nodes()))
# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])
for g in graphs:
    print(g.edges())
# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)
K = wl_kernel(graphs)
print("Kernel matrix (pairwise similarities):")
print(K)
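To make the suspected cause concrete, here is a minimal sketch that avoids importing networkx and only mimics the relevant behaviour: with its default ordering, `nx.convert_node_labels_to_integers()` numbers nodes in the graph's node iteration order, which for `add_edges_from()` is first-seen order. The helpers `insertion_order_mapping`, `shared_mapping` and `relabel` are hypothetical names for illustration, not part of `torch_wl_kernel`, and the shared mapping is one possible fix under the assumption that node IDs mean the same thing across graphs.

```python
def insertion_order_mapping(edges):
    """Mimic per-graph relabeling: number nodes in first-seen order."""
    seen = dict.fromkeys(node for edge in edges for node in edge)
    return {node: idx for idx, node in enumerate(seen)}


def shared_mapping(edge_lists):
    """Hypothetical fix: one mapping built from the sorted union of all node IDs."""
    nodes = sorted({node for edges in edge_lists for edge in edges for node in edge})
    return {node: idx for idx, node in enumerate(nodes)}


def relabel(edges, mapping):
    """Return the undirected edge set after applying `mapping` to both endpoints."""
    return {frozenset((mapping[u], mapping[v])) for u, v in edges}


original = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]
reordered = [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]

# Per-graph relabeling: the same graph ends up with two different edge sets.
per_graph_a = relabel(original, insertion_order_mapping(original))
per_graph_b = relabel(reordered, insertion_order_mapping(reordered))
print(per_graph_a == per_graph_b)  # False

# Shared relabeling: edge order no longer matters.
mapping = shared_mapping([original, reordered])
shared_a = relabel(original, mapping)
shared_b = relabel(reordered, mapping)
print(shared_a == shared_b)  # True
```

Whether this difference in node numbering is what actually changes the kernel output depends on how `TorchWLKernel` uses the indices, so treat it as a starting point for debugging rather than a confirmed diagnosis.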
Looks mostly good but I have some questions regarding the definition of the class within the actual Kernel.
Otherwise, looks really solid and the tests are really good!
combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None

# Create WL kernel for graphs
wl_kernel = TorchWLKernel(n_iter=5, normalize=True)
I'm guessing there's no way to really make it such that we could pass the TorchWLKernel to the AdditiveKernel, i.e. you would use it just like any other kernel type?
We could define a WLKernel class that extends gpytorch.kernels.Kernel and use that class instead of TorchWLKernel.
This pull request introduces a custom PyTorch implementation of the Weisfeiler-Lehman (WL) kernel to replace the existing Grakel-based implementation. The most important changes include adding new example scripts, creating the custom WL kernel class, and adding tests to ensure correctness and compatibility.
New Implementation of Weisfeiler-Lehman Kernel:
grakel_replace/torch_wl_kernel.py: Implemented a custom PyTorch class TorchWLKernel for the WL kernel, including methods to convert NetworkX graphs to sparse adjacency tensors, initialize node labels, perform WL iterations, and compute the kernel matrix. Also added a GraphDataset utility class for preprocessing NetworkX graphs.
Example Scripts:
grakel_replace/grakel_wl_usage_example.py: Added an example script demonstrating the usage of the Grakel library for computing the WL kernel matrix.
grakel_replace/torch_wl_usage_example.py: Added an example script demonstrating the usage of the new PyTorch-based WL kernel implementation.
Testing:
tests/test_torch_wl_kernel.py: Added tests to compare the custom WL kernel against Grakel's implementation, check kernel matrix symmetry, handle empty graphs, validate input types, check the same output for reordered graphs, and test single-node graphs.
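For readers unfamiliar with the steps listed above (initialize node labels, perform WL iterations, compute the kernel matrix), the following is a rough pure-Python sketch of a WL subtree kernel. It is not the code in grakel_replace/torch_wl_kernel.py: it uses plain dicts instead of sparse tensors, the function names are hypothetical, and using the node degree as the initial label is an assumption made here for illustration.

```python
from collections import Counter


def edges_to_adjacency(edges):
    """Build an undirected adjacency dict {node: set(neighbours)} from an edge list."""
    adjacency = {}
    for u, v in edges:
        adjacency.setdefault(u, set()).add(v)
        adjacency.setdefault(v, set()).add(u)
    return adjacency


def wl_feature_counts(adjacency, n_iter):
    """Count (iteration, label) pairs over n_iter rounds of WL label refinement."""
    # Assumption: initial label is the node degree.
    labels = {node: str(len(nbrs)) for node, nbrs in adjacency.items()}
    counts = Counter((0, label) for label in labels.values())
    for iteration in range(1, n_iter + 1):
        new_labels = {}
        for node, nbrs in adjacency.items():
            # New label = own label plus the sorted multiset of neighbour labels.
            neighbourhood = ",".join(sorted(labels[n] for n in nbrs))
            new_labels[node] = f"{labels[node]}({neighbourhood})"
        labels = new_labels
        counts.update((iteration, label) for label in labels.values())
    return counts


def wl_kernel_matrix(edge_lists, n_iter=2):
    """Unnormalized kernel: dot product of the label-count vectors of each graph pair."""
    features = [wl_feature_counts(edges_to_adjacency(e), n_iter) for e in edge_lists]
    return [[sum(c * fb[key] for key, c in fa.items()) for fb in features]
            for fa in features]


g1 = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]
g1_reordered = [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
g3 = [(0, 1), (1, 3), (3, 2)]

K = wl_kernel_matrix([g1, g1_reordered, g3], n_iter=2)
print(K)
```

Because labels here depend only on graph structure and never on node indices, the rows for `g1` and `g1_reordered` come out identical, which is the property the reordered-graph test in tests/test_torch_wl_kernel.py is meant to check.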