Skip to content

Commit

Permalink
Add a custom __call__ method to pass graphs during optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
vladislavalerievich committed Nov 21, 2024
1 parent 246f9f6 commit 770c626
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions grakel_replace/mixed_single_task_gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class MixedSingleTaskGP(SingleTaskGP):
_wl_kernel (TorchWLKernel): Instance of the Weisfeiler-Lehman kernel.
_train_graphs (list[nx.Graph]): Training graph instances.
_K_train (Tensor): Precomputed graph kernel matrix for training graphs.
train_inputs (tuple[Tensor, list[nx.Graph]]): Tuple of training inputs.
num_cat_kernel (Module | None): Kernel for numerical/categorical features.
"""

Expand Down Expand Up @@ -118,6 +119,9 @@ def __init__(
self._wl_kernel = wl_kernel or TorchWLKernel(n_iter=5, normalize=True)
self._train_graphs = train_graphs

# Store graphs as part of train_inputs for using them in the __call__ method
self.train_inputs = (train_X, train_graphs)

# Preprocess the training graphs into a compatible format and compute the graph
# kernel matrix
self._train_graph_dataset = GraphDataset.from_networkx(train_graphs)
Expand All @@ -133,6 +137,12 @@ def __init__(

self.num_cat_kernel = num_cat_kernel

def __call__(self, X: Tensor, graphs: list[nx.Graph] = None, **kwargs):
"""Custom __call__ method that retrieves graphs if not explicitly passed."""
if graphs is None: # Use stored graphs from train_inputs if not provided
graphs = self.train_inputs[1]
return super().__call__(X, graphs=graphs, **kwargs)

def forward(self, X: Tensor, graphs: list[nx.Graph]) -> MultivariateNormal:
"""Forward pass to compute the Gaussian Process distribution for given inputs.
Expand Down

0 comments on commit 770c626

Please sign in to comment.