
Possible PyTorch implementation of WL kernel #153

Open · wants to merge 16 commits into base: master

Conversation

@vladislavalerievich (Collaborator) commented Oct 28, 2024

This pull request introduces a custom PyTorch implementation of the Weisfeiler-Lehman (WL) kernel to replace the existing Grakel-based implementation. The most important changes include adding new example scripts, creating the custom WL kernel class, and adding tests to ensure correctness and compatibility.

New Implementation of Weisfeiler-Lehman Kernel:

  • grakel_replace/torch_wl_kernel.py: Implemented a custom PyTorch class TorchWLKernel for the WL kernel, including methods to convert NetworkX graphs to sparse adjacency tensors, initialize node labels, perform WL iterations, and compute the kernel matrix. Also added a GraphDataset utility class for preprocessing NetworkX graphs.
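To make the bullet above concrete, here is a minimal sketch of the WL label-refinement idea using plain Python dicts; the real TorchWLKernel works on sparse adjacency tensors, so the names and details below are illustrative assumptions rather than code from this PR.

# Sketch only: one WL refinement pass plus the label histogram a WL kernel compares.
from collections import Counter

import networkx as nx


def wl_relabel(G: nx.Graph, labels: dict) -> dict:
    # A node's new label hashes its old label together with the sorted
    # multiset of its neighbours' labels.
    return {
        node: hash((labels[node], tuple(sorted(labels[nbr] for nbr in G.neighbors(node)))))
        for node in G.nodes()
    }


def wl_histogram(G: nx.Graph, n_iter: int = 2) -> Counter:
    # Accumulate label counts over all iterations; the (unnormalised) kernel value
    # between two graphs is the dot product of their histograms.
    labels = {node: G.degree(node) for node in G.nodes()}
    hist = Counter(labels.values())
    for _ in range(n_iter):
        labels = wl_relabel(G, labels)
        hist.update(labels.values())
    return hist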

Example Scripts:

Testing:

  • tests/test_torch_wl_kernel.py: Added tests that compare the custom WL kernel against Grakel's implementation, check kernel-matrix symmetry, handle empty graphs, validate input types, verify that reordered graphs produce the same output, and cover single-node graphs.
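As an illustration of the last two checks, a hedged sketch of what the symmetry and edge-reordering tests might look like, using the GraphDataset/TorchWLKernel API from the reproduction further down; the test names and graphs are assumptions, not copies of tests/test_torch_wl_kernel.py.

# Illustrative pytest-style checks; not the actual tests from this PR.
import networkx as nx
import torch

from torch_wl_kernel import GraphDataset, TorchWLKernel


def test_kernel_matrix_is_symmetric():
    graphs = GraphDataset.from_networkx([nx.path_graph(4), nx.star_graph(3)])
    K = TorchWLKernel(n_iter=2, normalize=False)(graphs)
    assert torch.allclose(K, K.T)


def test_edge_order_does_not_change_kernel():
    edges = [(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]
    g_a = nx.Graph(edges)
    g_b = nx.Graph(list(reversed(edges)))  # same graph, edges added in another order
    kernel = TorchWLKernel(n_iter=2, normalize=False)
    K_a = kernel(GraphDataset.from_networkx([g_a, g_a]))
    K_b = kernel(GraphDataset.from_networkx([g_b, g_b]))
    assert torch.allclose(K_a, K_b)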

@eddiebergman (Contributor) left a comment

Seems really good as mentioned previously in MM!

There are some more changes we'll need to discuss regarding the representation we pass in and how we'll get it to work with BoTorch, but we can discuss that later. For now, there's a bug which, from the looks of things, extends to either a bug in GraKel or at least in how it's used, based on your comparison test against it.

Reproduction with description:

import networkx as nx
from torch_wl_kernel import GraphDataset, TorchWLKernel

# Create the same graphs as for the Grakel example

G1 = nx.Graph()
G1.add_edges_from([(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)])

G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])

# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])

# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)

K = wl_kernel(graphs)

print("Kernel matrix (pairwise similarities):")
print(K)

# Issue: GraphDataset.from_networkx() relabels nodes independently, which is incorrect;
# we assume the edges all refer to the same nodes, i.e. the 0, 1, 2, 3, 4 in the graphs
# above are not independent of each other.
# ------------------------------------------------------------------
# Below, we re-ordered the edges, placing the first edge at the end.
# This is the same graph, yet the kernel returns something different.
#
#
#     v-------------------------------------------v
#  [(0, 1), (1, 2), (1, 3), (1, 4), (2, 3)]
#          [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
#
# Take a look at the implementation of `nx.convert_node_labels_to_integers()`
# that is used in `GraphDataset.from_networkx()`. We likely need to create
# our own mapping and relabel the nodes as they do.
G1 = nx.Graph()
edges_g1 = [(1, 2), (1, 3), (1, 4), (2, 3), (0, 1)]
G1.add_edges_from(edges_g1)
print(list(G1.nodes()))

G2 = nx.Graph()
G2.add_edges_from([(0, 1), (1, 2), (2, 3), (3, 4)])
print(list(G2.nodes()))
G3 = nx.Graph()
G3.add_edges_from([(0, 1), (1, 3), (3, 2)])
print(list(G3.nodes()))

# Process graphs
graphs = GraphDataset.from_networkx([G1, G2, G3])
for g in graphs:
    print(g.edges())

# Initialize and run WL kernel
wl_kernel = TorchWLKernel(n_iter=2, normalize=False)

K = wl_kernel(graphs)

print("Kernel matrix (pairwise similarities):")
print(K)
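One possible direction for the relabeling fix mentioned in the comments above, sketched as a hypothetical helper (not part of this PR): map nodes to integers based on their sorted original labels rather than on insertion order.

# Hypothetical helper, not part of the PR: relabel nodes deterministically so the
# integer ids do not depend on the order in which edges were added.
import networkx as nx


def relabel_deterministically(G: nx.Graph) -> nx.Graph:
    mapping = {node: i for i, node in enumerate(sorted(G.nodes()))}
    return nx.relabel_nodes(G, mapping)

If the node labels are comparable, passing ordering="sorted" to nx.convert_node_labels_to_integers inside GraphDataset.from_networkx should have a similar effect.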

(Review threads on grakel_replace/grakel_wl_usage_example.py, grakel_replace/torch_wl_kernel.py, and tests/test_torch_wl_kernel.py; some are outdated or resolved.)
@eddiebergman (Contributor) left a comment

Looks mostly good, but I have some questions regarding the definition of the class within the actual Kernel.

Otherwise it looks really solid, and the tests are really good!

(Outdated, resolved review thread on grakel_replace/mixed_single_task_gp.py.)
Comment on lines +66 to +69
combined_num_cat_kernel = AdditiveKernel(*kernels) if kernels else None

# Create WL kernel for graphs
wl_kernel = TorchWLKernel(n_iter=5, normalize=True)
Contributor:

I'm guessing there's no way to really make it such that we could pass the TorchWLKernel to the AdditiveKernel, i.e. you would use it just like any other kernel type?

@vladislavalerievich (Collaborator, Author):

We could define a WLKernel class that extends gpytorch.kernels.Kernel and use that class instead of TorchWLKernel.
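A rough sketch of what such a wrapper might look like, ignoring batch dimensions and assuming we pass integer indices into a precomputed Gram matrix as the kernel's inputs; everything here is an assumption for illustration, not the final design.

# Hypothetical gpytorch wrapper so the WL kernel can sit inside AdditiveKernel
# like any other kernel; the index-into-a-precomputed-Gram-matrix design is an
# assumption, not code from this PR.
import torch
from gpytorch.kernels import Kernel


class WLKernel(Kernel):
    has_lengthscale = False

    def __init__(self, K_full: torch.Tensor, **kwargs):
        # K_full: Gram matrix over all graphs, e.g. the output of TorchWLKernel
        # applied to the full GraphDataset.
        super().__init__(**kwargs)
        self.register_buffer("K_full", K_full)

    def forward(self, x1, x2, diag=False, **params):
        # x1, x2 carry integer indices into K_full (shapes (n, 1) and (m, 1)).
        i = x1.long().squeeze(-1)
        j = x2.long().squeeze(-1)
        K = self.K_full[i][:, j]
        return K.diagonal(dim1=-2, dim2=-1) if diag else K

With something along these lines, AdditiveKernel(*kernels, ScaleKernel(WLKernel(K_full))) would treat graph similarity just like the numerical and categorical kernels.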

@vladislavalerievich added the enhancement (New feature or request) label on Nov 22, 2024