Skip to content

Commit

Permalink
hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusOrsoe committed May 1, 2024
1 parent 6840457 commit 71b56a4
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions src/graphnet/datasets/prometheus_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,29 +36,33 @@ def _prepare_args(
"""Prepare arguments for dataset.
Args:
- backend: backend of dataset. Either "parquet" or "sqlite"
+ backend: backend of dataset. Either "parquet" or "sqlite".
features: List of features from user to use as input.
truth: List of event-level truth variables from user.
Returns: Dataset arguments, train/val selection, test selection
"""
- if backend == 'sqlite':
+ if backend == "sqlite":
dataset_paths = glob(os.path.join(self.dataset_dir, "*.db"))
assert len(dataset_paths) == 1
dataset_path = dataset_paths[0]
event_nos = query_database(
- database=dataset_path,
- query=f"SELECT event_no FROM {self._truth_table[0]}"
+ database=dataset_path,
+ query=f"SELECT event_no FROM {self._truth_table[0]}",
)
train_val, test = train_test_split(
event_nos["event_no"].tolist(),
test_size=0.10,
random_state=42,
shuffle=True,
)
- elif backend == 'parquet':
+ elif backend == "parquet":
dataset_path = self.dataset_dir
- n_batches = len(glob(os.path.join(dataset_path,self._truth_table,'*.parquet')))
+ n_batches = len(
+ glob(
+ os.path.join(dataset_path, self._truth_table, "*.parquet")
+ )
+ )
train_val, test = train_test_split(
np.arange(0, n_batches),
test_size=0.10,
Expand Down

0 comments on commit 71b56a4

Please sign in to comment.