Skip to content

Commit

Permalink
adding option to remove SVs
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-moreno authored Nov 13, 2019
1 parent 088ba3d commit b30c3ab
Showing 1 changed file with 18 additions and 15 deletions.
33 changes: 18 additions & 15 deletions IN_dataGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,18 @@
import numpy as np
import pandas as pd
import util
import imp
try:
imp.find_module('setGPU')
import setGPU
except ImportError:
pass
import setGPU
import glob
import sys
import tqdm
import argparse

#sys.path.insert(0, '/nfshome/jduarte/DL4Jets/mpi_learn/mpi_learn/train')
print(torch.__version__)

os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
if os.path.isdir('/bigdata/shared/BumbleB'):
test_path = '/bigdata/shared/BumbleB/convert_20181121_ak8_80x_deepDoubleB_db_pf_cpf_sv_dl4jets_test/'
train_path = '/bigdata/shared/BumbleB/convert_20181121_ak8_80x_deepDoubleB_db_pf_cpf_sv_dl4jets_train_val/'
elif os.path.isdir('/eos/user/w/woodson/IN'):
test_path = '/eos/user/w/woodson/IN/convert_20181121_ak8_80x_deepDoubleB_db_pf_cpf_sv_dl4jets_test/'
train_path = '/eos/user/w/woodson/IN/convert_20181121_ak8_80x_deepDoubleB_db_pf_cpf_sv_dl4jets_train_val/'

test_path = '/bigdata/shared/BumbleB/convert_20181121_ak8_80x_deepDoubleB_db_pf_cpf_sv_dl4jets_test/'
train_path = '/bigdata/shared/BumbleB/convert_20181121_ak8_80x_deepDoubleB_db_pf_cpf_sv_dl4jets_train_val/'
NBINS = 40 # number of bins for loss function
MMAX = 200. # max value
MMIN = 40. # min value
Expand Down Expand Up @@ -99,6 +90,7 @@ def main(args):
label = 'new'
outdir = args.outdir
vv_branch = args.vv_branch
sv_branch = args.sv_branch
os.system('mkdir -p %s'%outdir)

batch_size = 128
Expand All @@ -122,12 +114,22 @@ def main(args):
print("val data:", n_val)
print("train data:", n_train)

from gnn import GraphNetnoSV
from gnn import GraphNet

gnn = GraphNet(N, n_targets, len(params), args.hidden, N_sv, len(params_sv),
if sv_branch:
gnn = GraphNet(N, n_targets, len(params), args.hidden, N_sv, len(params_sv),
vv_branch=int(vv_branch),
De=args.De,
Do=args.Do)
else:
gnn = GraphNetnoSV(N, n_targets, len(params), args.hidden, N_sv, len(params_sv),
sv_branch=int(sv_branch),
vv_branch=int(vv_branch),
De=args.De,
Do=args.Do)


# pre load best model
#gnn.load_state_dict(torch.load('out/gnn_new_best.pth'))

Expand Down Expand Up @@ -221,7 +223,7 @@ def main(args):
loss_vals_validation[m] = l_val
loss_std_validation[m] = np.std(np.array(loss_val))
loss_std_training[m] = np.std(np.array(loss_training))
if m > 5 and all(loss_vals_validation[max(0, m - 5):m] > min(np.append(loss_vals_validation[0:max(0, m - 5)], 200))):
if m > 8 and all(loss_vals_validation[max(0, m - 8):m] > min(np.append(loss_vals_validation[0:max(0, m - 8)], 200))):
print('Early Stopping...')
print(loss_vals_training, '\n', np.diff(loss_vals_training))
break
Expand All @@ -244,6 +246,7 @@ def main(args):

# Required positional arguments
parser.add_argument("outdir", help="Required output directory")
parser.add_argument("sv_branch", help="Required positional argument")
parser.add_argument("vv_branch", help="Required positional argument")

# Optional arguments
Expand Down

0 comments on commit b30c3ab

Please sign in to comment.