From 8ea30c77dfc45a96caaf0f1c589c4831a3b1c5f4 Mon Sep 17 00:00:00 2001 From: Eric Moreno Date: Tue, 12 May 2020 22:16:48 -0700 Subject: [PATCH] final bug fixes for paper --- IN_dataGenerator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/IN_dataGenerator.py b/IN_dataGenerator.py index 15bbea2..a111110 100644 --- a/IN_dataGenerator.py +++ b/IN_dataGenerator.py @@ -187,12 +187,12 @@ def main(args): for sub_X,sub_Y,sub_Z in tqdm.tqdm(data_train.generate_data(),total=n_train/batch_size): training = sub_X[2] - training_neu = sub_X[1] + #training_neu = sub_X[1] training_sv = sub_X[3] target = sub_Y[0] spec = sub_Z[0] trainingv = (torch.FloatTensor(training)).cuda() - trainingv_neu = (torch.FloatTensor(training_neu)).cuda() + #trainingv_neu = (torch.FloatTensor(training_neu)).cuda() trainingv_sv = (torch.FloatTensor(training_sv)).cuda() targetv = (torch.from_numpy(np.argmax(target, axis = 1)).long()).cuda() optimizer.zero_grad() @@ -200,7 +200,7 @@ def main(args): #Input training dataset if sv_branch: - out = gnn(trainingv.cuda(), trainingv_sv.cuda(), trainingv_neu.cuda()) + out = gnn(trainingv.cuda(), trainingv_sv.cuda()) else: out = gnn(trainingv.cuda()) @@ -213,18 +213,18 @@ def main(args): for sub_X,sub_Y,sub_Z in tqdm.tqdm(data_val.generate_data(),total=n_val/batch_size): training = sub_X[2] - training_neu = sub_X[1] + #training_neu = sub_X[1] training_sv = sub_X[3] target = sub_Y[0] spec = sub_Z[0] trainingv = (torch.FloatTensor(training)).cuda() - trainingv_neu = (torch.FloatTensor(training_neu)).cuda() + #trainingv_neu = (torch.FloatTensor(training_neu)).cuda() trainingv_sv = (torch.FloatTensor(training_sv)).cuda() targetv = (torch.from_numpy(np.argmax(target, axis = 1)).long()).cuda() #Input validation dataset if sv_branch: - out = gnn(trainingv.cuda(), trainingv_sv.cuda(), trainingv_neu.cuda()) + out = gnn(trainingv.cuda(), trainingv_sv.cuda()) else: out = gnn(trainingv.cuda())