Skip to content

Commit

Permalink
Merge pull request #302 from asogaard/restructure-models
Browse files Browse the repository at this point in the history
Restructure models
  • Loading branch information
asogaard authored Oct 6, 2022
2 parents 9808766 + 39c9148 commit d109faf
Show file tree
Hide file tree
Showing 5 changed files with 426 additions and 410 deletions.
1 change: 1 addition & 0 deletions examples/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def main():
)
gnn = DynEdge(
nb_inputs=detector.nb_outputs,
global_pooling_schemes=["min", "max", "mean", "sum"],
)
task = EnergyReconstruction(
hidden_size=gnn.nb_outputs,
Expand Down
4 changes: 2 additions & 2 deletions src/graphnet/models/components/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Optional, Sequence
from typing import Callable, List, Optional, Sequence, Union

from torch.functional import Tensor

Expand All @@ -13,7 +13,7 @@ def __init__(
nn: Callable,
aggr: str = "max",
nb_neighbors: int = 8,
features_subset: Optional[Sequence] = None,
features_subset: Optional[Union[Sequence[int], List[int]]] = None,
**kwargs,
):
# Check(s)
Expand Down
3 changes: 2 additions & 1 deletion src/graphnet/models/gnn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .dynedge import DynEdge, DynEdge_V2, DynEdge_V3
from .convnet import ConvNet
from .dynedge import DynEdge
from .dynedge_jinst import DynEdgeJINST
Loading

0 comments on commit d109faf

Please sign in to comment.