diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py index cfb814f94..bef153097 100644 --- a/src/graphnet/models/standard_model.py +++ b/src/graphnet/models/standard_model.py @@ -7,6 +7,7 @@ from torch.optim import Adam from graphnet.models.gnn.gnn import GNN +from graphnet.models import Model from .easy_model import EasySyntax from graphnet.models.task import StandardLearnedTask from graphnet.models.graphs import GraphDefinition @@ -25,7 +26,7 @@ def __init__( self, graph_definition: GraphDefinition, tasks: Union[StandardLearnedTask, List[StandardLearnedTask]], - backbone: GNN = None, + backbone: Model = None, gnn: Optional[GNN] = None, optimizer_class: Type[torch.optim.Optimizer] = Adam, optimizer_kwargs: Optional[Dict] = None, @@ -60,7 +61,7 @@ def __init__( ) # Checks - assert isinstance(backbone, GNN) + assert isinstance(backbone, Model) assert isinstance(graph_definition, GraphDefinition) # Member variable(s)