Skip to content

Commit

Permalink
final bug fixes for paper
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric Moreno committed May 13, 2020
1 parent f6a2d4d commit 8ea30c7
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions IN_dataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,20 @@ def main(args):

for sub_X,sub_Y,sub_Z in tqdm.tqdm(data_train.generate_data(),total=n_train/batch_size):
training = sub_X[2]
training_neu = sub_X[1]
#training_neu = sub_X[1]
training_sv = sub_X[3]
target = sub_Y[0]
spec = sub_Z[0]
trainingv = (torch.FloatTensor(training)).cuda()
trainingv_neu = (torch.FloatTensor(training_neu)).cuda()
#trainingv_neu = (torch.FloatTensor(training_neu)).cuda()
trainingv_sv = (torch.FloatTensor(training_sv)).cuda()
targetv = (torch.from_numpy(np.argmax(target, axis = 1)).long()).cuda()
optimizer.zero_grad()
#out = gnn(trainingv.cuda(), trainingv_sv.cuda())

#Input training dataset
if sv_branch:
out = gnn(trainingv.cuda(), trainingv_sv.cuda(), trainingv_neu.cuda())
out = gnn(trainingv.cuda(), trainingv_sv.cuda())
else:
out = gnn(trainingv.cuda())

Expand All @@ -213,18 +213,18 @@ def main(args):

for sub_X,sub_Y,sub_Z in tqdm.tqdm(data_val.generate_data(),total=n_val/batch_size):
training = sub_X[2]
training_neu = sub_X[1]
#training_neu = sub_X[1]
training_sv = sub_X[3]
target = sub_Y[0]
spec = sub_Z[0]
trainingv = (torch.FloatTensor(training)).cuda()
trainingv_neu = (torch.FloatTensor(training_neu)).cuda()
#trainingv_neu = (torch.FloatTensor(training_neu)).cuda()
trainingv_sv = (torch.FloatTensor(training_sv)).cuda()
targetv = (torch.from_numpy(np.argmax(target, axis = 1)).long()).cuda()

#Input validation dataset
if sv_branch:
out = gnn(trainingv.cuda(), trainingv_sv.cuda(), trainingv_neu.cuda())
out = gnn(trainingv.cuda(), trainingv_sv.cuda())
else:
out = gnn(trainingv.cuda())

Expand Down

0 comments on commit 8ea30c7

Please sign in to comment.