Improvements to parquet dataloading, sampling, batch sampling #742

Merged
7 commits merged on Sep 11, 2024
34 changes: 33 additions & 1 deletion src/graphnet/data/datamodule.py
@@ -273,6 +273,39 @@ def _create_dataloader(
"Unknown dataset encountered during dataloader creation."
)

if "sampler" in dataloader_args.keys():
# If there were no kwargs provided, set it to empty dict
if "sampler_kwargs" not in dataloader_args.keys():
dataloader_args["sampler_kwargs"] = {}
dataloader_args["sampler"] = dataloader_args["sampler"](
dataset, **dataloader_args["sampler_kwargs"]
)
del dataloader_args["sampler_kwargs"]

if "batch_sampler" in dataloader_args.keys():
if "sampler" not in dataloader_args.keys():
raise KeyError(
"When specifying a `batch_sampler`,"
"you must also provide `sampler`."
)
# If there were no kwargs provided, set it to empty dict
if "batch_sampler_kwargs" not in dataloader_args.keys():
dataloader_args["batch_sampler_kwargs"] = {}

batch_sampler = dataloader_args["batch_sampler"](
dataloader_args["sampler"],
**dataloader_args["batch_sampler_kwargs"],
)
dataloader_args["batch_sampler"] = batch_sampler
# Remove extra keys
for key in [
"batch_sampler_kwargs",
"drop_last",
"sampler",
"shuffle",
]:
dataloader_args.pop(key, None)

if dataloader_args is None:
raise AttributeError("Dataloader arguments not provided.")

@@ -479,7 +512,6 @@ def _infer_selections_on_single_dataset(
             .sample(frac=1, replace=False, random_state=self._rng)
             .values.tolist()
         ) # shuffled list
-
         return self._split_selection(all_events)
 
     def _construct_dataset(self, tmp_args: Dict[str, Any]) -> Dataset:
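With this change, the dataloader kwargs passed to the DataModule can name a sampler class, a batch-sampler class, and their constructor kwargs; `_create_dataloader` then instantiates the sampler with the dataset, wraps it in the batch sampler, and strips the keys that `torch.utils.data.DataLoader` rejects alongside a `batch_sampler` (`sampler`, `shuffle`, `drop_last`, plus the `*_kwargs` entries). A rough usage sketch follows: `RandomChunkSampler` and `LenMatchBatchSampler` are the classes this PR exports from `graphnet.data.dataset`, but the constructor kwargs shown are assumptions, not taken from this diff.

from graphnet.data.dataset import LenMatchBatchSampler, RandomChunkSampler

# The sampler and batch-sampler entries are passed as classes, not instances;
# `_create_dataloader` builds `sampler = sampler_cls(dataset, **sampler_kwargs)`
# and `batch_sampler = batch_sampler_cls(sampler, **batch_sampler_kwargs)`,
# then pops the keys DataLoader will not accept next to `batch_sampler`.
train_dataloader_kwargs = {
    "num_workers": 4,
    "sampler": RandomChunkSampler,
    "batch_sampler": LenMatchBatchSampler,
    # Assumed kwargs; see graphnet.data.dataset.samplers for the real signature.
    "batch_sampler_kwargs": {"batch_size": 256, "drop_last": True},
}

Passing classes rather than instances lets the DataModule construct the sampler against whichever split's dataset it is currently building.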
4 changes: 4 additions & 0 deletions src/graphnet/data/dataset/__init__.py
@@ -5,6 +5,10 @@
 if has_torch_package():
     import torch.multiprocessing
     from .dataset import EnsembleDataset, Dataset, ColumnMissingException
+    from .samplers import (
+        RandomChunkSampler,
+        LenMatchBatchSampler,
+    )
     from .parquet.parquet_dataset import ParquetDataset
     from .sqlite.sqlite_dataset import SQLiteDataset
 
28 changes: 17 additions & 11 deletions src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -5,6 +5,7 @@
     List,
     Optional,
     Union,
+    Any,
 )
 
 import numpy as np
@@ -92,7 +93,7 @@ def __init__(
`"10000 random events ~ event_no % 5 > 0"` or `"20% random
events ~ event_no % 5 > 0"`).
graph_definition: Method that defines the graph representation.
cache_size: Number of batches to cache in memory.
cache_size: Number of files to cache in memory.
Must be at least 1. Defaults to 1.
labels: Dictionary of labels to be added to the dataset.
"""
@@ -123,8 +124,8 @@ def __init__(
         self._path: str = self._path
         # Member Variables
         self._cache_size = cache_size
-        self._batch_sizes = self._calculate_sizes()
-        self._batch_cumsum = np.cumsum(self._batch_sizes)
+        self._chunk_sizes = self._calculate_sizes()
+        self._chunk_cumsum = np.cumsum(self._chunk_sizes)
         self._file_cache = self._initialize_file_cache(
             truth_table=truth_table,
             node_truth_table=node_truth_table,
@@ -179,32 +180,37 @@ def _get_event_index(self, sequential_index: int) -> int:
         )
         return event_index
 
+    @property
+    def chunk_sizes(self) -> List[int]:
+        """Return a list of the chunk sizes."""
+        return self._chunk_sizes
+
     def __len__(self) -> int:
         """Return length of dataset, i.e. number of training examples."""
-        return sum(self._batch_sizes)
+        return sum(self._chunk_sizes)
 
     def _get_all_indices(self) -> List[int]:
         """Return a list of all unique values in `self._index_column`."""
         files = glob(os.path.join(self._path, self._truth_table, "*.parquet"))
         return np.arange(0, len(files), 1)
 
     def _calculate_sizes(self) -> List[int]:
-        """Calculate the number of events in each batch."""
+        """Calculate the number of events in each chunk."""
         sizes = []
-        for batch_id in self._indices:
+        for chunk_id in self._indices:
             path = os.path.join(
                 self._path,
                 self._truth_table,
-                f"{self.truth_table}_{batch_id}.parquet",
+                f"{self.truth_table}_{chunk_id}.parquet",
             )
             sizes.append(len(pol.read_parquet(path)))
         return sizes
 
     def _get_row_idx(self, sequential_index: int) -> int:
         """Return the row index corresponding to a `sequential_index`."""
-        file_idx = bisect_right(self._batch_cumsum, sequential_index)
+        file_idx = bisect_right(self._chunk_cumsum, sequential_index)
         if file_idx > 0:
-            idx = int(sequential_index - self._batch_cumsum[file_idx - 1])
+            idx = int(sequential_index - self._chunk_cumsum[file_idx - 1])
         else:
             idx = sequential_index
         return idx
@@ -241,9 +247,9 @@ def query_table( # type: ignore
             columns = [columns]
 
         if sequential_index is None:
-            file_idx = np.arange(0, len(self._batch_cumsum), 1)
+            file_idx = np.arange(0, len(self._chunk_cumsum), 1)
         else:
-            file_idx = [bisect_right(self._batch_cumsum, sequential_index)]
+            file_idx = [bisect_right(self._chunk_cumsum, sequential_index)]
 
         file_indices = [self._indices[idx] for idx in file_idx]
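The renamed chunk bookkeeping is a cumulative-sum lookup: `_chunk_cumsum` holds the running total of events per parquet file, and `bisect_right` turns a global sequential index into a file index plus a row offset within that file, as in `_get_row_idx` and `query_table` above. A minimal, self-contained sketch with illustrative numbers (names here are not part of the PR):

from bisect import bisect_right

import numpy as np

# Three parquet files ("chunks") holding 100, 250 and 50 events, respectively.
chunk_sizes = [100, 250, 50]
chunk_cumsum = np.cumsum(chunk_sizes)  # array([100, 350, 400])


def locate(sequential_index: int) -> tuple:
    """Map a global event index to (file index, row within that file)."""
    file_idx = bisect_right(chunk_cumsum, sequential_index)
    if file_idx > 0:
        row_idx = int(sequential_index - chunk_cumsum[file_idx - 1])
    else:
        row_idx = sequential_index
    return file_idx, row_idx


assert locate(0) == (0, 0)     # first event, first file
assert locate(99) == (0, 99)   # last event of the first file
assert locate(100) == (1, 0)   # first event of the second file
assert locate(399) == (2, 49)  # last event overall

Because the lookup is a binary search over the number of files, it stays cheap even when a dataset is split across many parquet chunks.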