StandardModel, assertion of backbone #743

Closed
niklasmei opened this issue Sep 9, 2024 · 1 comment
Labels
bug Something isn't working

@niklasmei
Collaborator

Describe the bug
When passing a generic custom backbone to `StandardModel`, the constructor requires the backbone to be an instance of `GNN`, when it should probably only require it to be an instance of `Model`.
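A minimal, library-free illustration of the mismatch follows. It assumes, as in GraphNeT, that `GNN` is itself a subclass of `Model`. The classes `Model`, `GNN` and `CustomBackbone` and the helpers `strict_check`/`relaxed_check` below are hypothetical stand-ins, not the real GraphNeT code: the "strict" check mirrors the current `isinstance(backbone, GNN)` assertion, and the "relaxed" one is the `Model`-only requirement proposed here.

```python
from abc import ABC, abstractmethod


# Hypothetical stand-ins mirroring the assumed hierarchy: GNN subclasses Model.
class Model(ABC):
    @abstractmethod
    def forward(self, data):
        ...


class GNN(Model):
    """Stand-in for a GNN backbone that declares its input/output sizes."""

    def __init__(self, nb_inputs: int, nb_outputs: int) -> None:
        self.nb_inputs = nb_inputs
        self.nb_outputs = nb_outputs


class CustomBackbone(Model):
    """A generic backbone that inherits from Model only."""

    def forward(self, data):
        return data


def strict_check(backbone) -> bool:
    # Mirrors the current check: only GNN subclasses pass.
    return isinstance(backbone, GNN)


def relaxed_check(backbone) -> bool:
    # The relaxation proposed in this issue: any Model subclass passes.
    return isinstance(backbone, Model)


backbone = CustomBackbone()
print(strict_check(backbone))   # False: the assertion in StandardModel fails
print(relaxed_check(backbone))  # True
```

Because every `GNN` is also a `Model`, the relaxed check accepts everything the strict one does, plus generic `Model` backbones.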

To Reproduce
Steps to reproduce the behavior:
Run the following code with the current version of GraphNeT (version 1.0.0+3861.gdb349416):

```python
import torch
from graphnet.models import StandardModel
from graphnet.models.graphs import KNNGraph
from graphnet.models.detector.prometheus import PONETriangle
from graphnet.models import Model
from graphnet.models.gnn.gnn import GNN
from torch_geometric.data import Data
from graphnet.models.task import StandardLearnedTask
from graphnet.training.loss_functions import LogCoshLoss

graph_definition = KNNGraph(detector=PONETriangle())


class MyGraphNeTModel(Model):
    """Custom backbone that inherits from Model, not from GNN."""

    def __init__(self, input_dim: int = 4, output_dim: int = 10):
        super().__init__()
        self._layer = torch.nn.Linear(input_dim, output_dim)

    def forward(self, data: Data) -> torch.Tensor:
        x = data.x
        return self._layer(x)


class BinaryClassificationTask(StandardLearnedTask):
    nb_inputs = 1
    default_target_labels = ["target"]
    default_prediction_labels = ["target_pred"]

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(x)


mybackbone = MyGraphNeTModel()
task = BinaryClassificationTask(
    hidden_size=10,  # would be backbone.nb_outputs for a GNN backbone
    loss_function=LogCoshLoss(),  # pass an instance, not the class
)

# Raises AssertionError: the backbone is a Model but not a GNN.
model = StandardModel(
    graph_definition=graph_definition,
    backbone=mybackbone,
    tasks=task,
)
```

Note that `MyGraphNeTModel` inherits from `Model` but not from `GNN`, which is what triggers the error.

Expected behavior
I expect the `StandardModel` to be built and ready for further use; instead, an `AssertionError` is raised.

Full traceback

```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[8], line 38
     34 mybackbone = MyGraphNeTModel()
     35 task = BinaryClassificationTask(hidden_size = 10, #backbone.nb_outputs,
     36                                 loss_function = LogCoshLoss)
---> 38 model = StandardModel(
     39     graph_definition=graph_definition,
     40     backbone=mybackbone, 
     41     tasks=task
     42     )

File /mnt/c/Users/nikla/Desktop/GraphNet/environment/graphnet/src/graphnet/utilities/config/model_config.py:335, in ModelConfigSaverMeta.__call__(cls, *args, **kwargs)
    332         return obj
    334 # Create object
--> 335 created_obj = super().__call__(*args, **kwargs)
    337 # Get all argument values, including defaults
    338 cfg = get_all_argument_values(created_obj.__init__, *args, **kwargs)

File /mnt/c/Users/nikla/Desktop/GraphNet/environment/graphnet/src/graphnet/models/standard_model.py:63, in StandardModel.__init__(self, graph_definition, tasks, backbone, gnn, optimizer_class, optimizer_kwargs, scheduler_class, scheduler_kwargs, scheduler_config)
     58     raise TypeError(
     59         "__init__() missing 1 required keyword argument:'backbone'"
     60     )
     62 # Checks
---> 63 assert isinstance(backbone, GNN)
     64 assert isinstance(graph_definition, GraphDefinition)
     66 # Member variable(s)

AssertionError:
```

Additional context
Changing `MyGraphNeTModel` to inherit from `GNN` instead of `Model` (plus some minor adjustments that come with this change) solves the problem.
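The "minor adjustments" are not spelled out above. A plausible shape, assuming the `GNN` base class takes `nb_inputs`/`nb_outputs` in its constructor and exposes them as properties, is sketched below. The `Model` and `GNN` classes here are hypothetical, dependency-free stand-ins that only mimic that assumed interface; they are not GraphNeT's actual classes, and `forward` is a placeholder for the `torch.nn.Linear` layer in the report.

```python
from abc import ABC, abstractmethod
from typing import Any


class Model(ABC):
    """Hypothetical stand-in for graphnet.models.Model."""

    @abstractmethod
    def forward(self, data: Any) -> Any:
        ...


class GNN(Model):
    """Hypothetical stand-in for a GNN base class that records its I/O sizes."""

    def __init__(self, nb_inputs: int, nb_outputs: int) -> None:
        self._nb_inputs = nb_inputs
        self._nb_outputs = nb_outputs

    @property
    def nb_inputs(self) -> int:
        return self._nb_inputs

    @property
    def nb_outputs(self) -> int:
        return self._nb_outputs


class MyGraphNeTModel(GNN):
    """The custom backbone from the report, now inheriting from GNN."""

    def __init__(self, input_dim: int = 4, output_dim: int = 10) -> None:
        # Assumed adjustment: forward the sizes to the GNN base constructor.
        super().__init__(nb_inputs=input_dim, nb_outputs=output_dim)

    def forward(self, data: Any) -> Any:
        # Placeholder; the real backbone applies torch.nn.Linear here.
        return data


mybackbone = MyGraphNeTModel()
assert isinstance(mybackbone, GNN)   # would pass the existing backbone check
hidden_size = mybackbone.nb_outputs  # replaces the hard-coded hidden_size = 10
print(hidden_size)  # 10
```

A side benefit of the `GNN` route is that the task's `hidden_size` can be read from `backbone.nb_outputs`, as the commented-out hint in the reproduction code suggests, instead of being hard-coded.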

@RasmusOrsoe
Collaborator

closed by #744
