diff --git a/main.py b/main.py index 8a844be..df5ffb3 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from NeuroGraph.datasets import NeuroGraphStatic +from NeuroGraph.datasets import NeuroGraphDataset import argparse import torch import torch.nn.functional as F @@ -48,7 +48,7 @@ def logger(info): np.random.seed(args.seed) -dataset = NeuroGraphStatic(root=root, dataset_name= args.dataset) +dataset = NeuroGraphDataset(root=root, name= args.dataset) print(dataset.num_classes) print(len(dataset)) @@ -95,7 +95,7 @@ def test(loader): return correct / len(loader.dataset) val_acc_history, test_acc_history, test_loss_history = [],[],[] -seeds = [123,124,125,126,127,128,129,221,223,224,228,229] +seeds = [123,124] for index in range(args.runs): start = time.time() torch.manual_seed(seeds[index])