Describe the bug
When implementing a generic custom backbone for `StandardModel`, `StandardModel` currently requires the backbone to be an instance of `GNN`, when it should probably only require it to be a `Model`.
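The difference between the current requirement and the proposed one can be illustrated with a small, self-contained sketch. It assumes the check inside `StandardModel` is a plain `isinstance` assertion, and it uses hypothetical stand-in classes instead of importing GraphNeT; `check_backbone_strict`, `check_backbone_relaxed` and `passes` are illustrative names, not part of the library.

```python
# Hypothetical stand-ins for graphnet.models.Model and graphnet.models.gnn.gnn.GNN.
# They only mirror the inheritance relationship (GNN is a subclass of Model).
class Model:
    """Stand-in for GraphNeT's Model base class."""


class GNN(Model):
    """Stand-in for GraphNeT's GNN base class."""


class MyGraphNeTModel(Model):
    """Custom backbone that subclasses Model but not GNN."""


def check_backbone_strict(backbone: object) -> None:
    # ASSUMPTION: mirrors the current behaviour, where only GNN instances pass.
    assert isinstance(backbone, GNN), "backbone must be a GNN"


def check_backbone_relaxed(backbone: object) -> None:
    # Proposed relaxation: any Model instance is accepted as a backbone.
    assert isinstance(backbone, Model), "backbone must be a Model"


def passes(check, backbone: object) -> bool:
    """Return True if `check` accepts `backbone`, False if it raises AssertionError."""
    try:
        check(backbone)
    except AssertionError:
        return False
    return True


backbone = MyGraphNeTModel()
print(passes(check_backbone_strict, backbone))   # False
print(passes(check_backbone_relaxed, backbone))  # True
```

Because `GNN` is itself a subclass of `Model`, the relaxed check would still accept every existing `GNN` backbone.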
To Reproduce
Steps to reproduce the behavior:
Run the following code in the current version of GraphNet (version: 1.0.0+3861.gdb349416)
```python
import torch
from graphnet.models import StandardModel
from graphnet.models.graphs import KNNGraph
from graphnet.models.detector.prometheus import PONETriangle
from graphnet.models import Model
from graphnet.models.gnn.gnn import GNN
from torch_geometric.data import Data
from graphnet.models.task import StandardLearnedTask
from graphnet.training.loss_functions import LogCoshLoss

graph_definition = KNNGraph(detector=PONETriangle())


class MyGraphNeTModel(Model):
    """Minimal custom backbone: a single linear layer on the node features."""

    def __init__(self, input_dim: int = 4, output_dim: int = 10):
        super().__init__()
        self._layer = torch.nn.Linear(input_dim, output_dim)

    def forward(self, data: Data) -> torch.Tensor:
        x = data.x
        return self._layer(x)


class BinaryClassificationTask(StandardLearnedTask):
    nb_inputs = 1
    default_target_labels = ["target"]
    default_prediction_labels = ["target_pred"]

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(x)


mybackbone = MyGraphNeTModel()
task = BinaryClassificationTask(
    hidden_size=10,  # backbone.nb_outputs
    loss_function=LogCoshLoss(),  # the task expects a loss-function instance
)

# Raises an AssertionError, because mybackbone is not a GNN instance.
model = StandardModel(
    graph_definition=graph_definition,
    backbone=mybackbone,
    tasks=task,
)
```
Note that `MyGraphNeTModel` subclasses `Model` rather than `GNN`, which is what triggers the error.
Expected behavior
I expect the `StandardModel` to be built and ready for further use, but instead I get an `AssertionError`.
Additional context
Changing the inheritance of `MyGraphNeTModel` from `Model` to `GNN` (plus some minor adjustments that come with this change) solved the problem.
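The issue does not spell out the "minor adjustments", so the following is only a sketch of what the workaround might look like. It assumes (unverified here) that the `GNN` base class takes the backbone's input and output sizes in its constructor and exposes them as `nb_inputs` and `nb_outputs`; the classes below are hypothetical stand-ins, not the real GraphNeT classes, and `forward` is omitted for brevity.

```python
# Hypothetical stand-ins; the real GraphNeT classes are not imported here.
class Model:
    """Stand-in for GraphNeT's Model base class."""


class GNN(Model):
    """Stand-in for a GNN base class that records its input/output sizes.

    ASSUMPTION: the real GNN base class takes the input and output
    dimensions in its constructor; check the installed GraphNeT version.
    """

    def __init__(self, nb_inputs: int, nb_outputs: int) -> None:
        super().__init__()
        self._nb_inputs = nb_inputs
        self._nb_outputs = nb_outputs

    @property
    def nb_inputs(self) -> int:
        return self._nb_inputs

    @property
    def nb_outputs(self) -> int:
        return self._nb_outputs


class MyGraphNeTModel(GNN):
    """Workaround: inherit from GNN instead of Model."""

    def __init__(self, input_dim: int = 4, output_dim: int = 10) -> None:
        # ASSUMED adjustment: forward the dimensions to the GNN base class.
        super().__init__(nb_inputs=input_dim, nb_outputs=output_dim)


backbone = MyGraphNeTModel()
print(isinstance(backbone, GNN))  # True
print(backbone.nb_outputs)        # 10
```

Under this assumption, `backbone.nb_outputs` could then be passed as the task's `hidden_size`, which is what the commented-out `backbone.nb_outputs` hint in the reproduction script points at.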