Skip to content

Commit

Permalink
Fix ensemble dataset functionality and add unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: samadpls <[email protected]>
  • Loading branch information
samadpls committed Feb 9, 2024
1 parent 0b3d3f1 commit 57f49fe
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/graphnet/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,8 +325,8 @@ def _infer_selections(self) -> Tuple[List[int], List[int]]:
train_selection,
val_selection,
) = self._infer_selections_on_single_dataset(dataset_path)
self._train_selection.extend(train_selection) # type: ignore
self._val_selection.extend(val_selection) # type: ignore
self._train_selection.append(train_selection) # type: ignore
self._val_selection.append(val_selection) # type: ignore
else:
# Infer selection on a single dataset
(
Expand Down
102 changes: 102 additions & 0 deletions tests/data/test_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,105 @@ def test_dataloader_args(
assert (
dm.test_dataloader.batch_size == test_dataloader_kwargs["batch_size"]
)


@pytest.mark.parametrize(
    "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True
)
def test_ensemble_dataset_without_selections(
    dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]]
) -> None:
    """Check that an ensemble of datasets yields more batches than one.

    A datamodule built on a single dataset path is compared against a
    datamodule built on the same path listed twice; no selections are
    passed to either of them.

    Args:
        dataset_setup: Tuple holding the dataset reference, the dataset
            keyword arguments and the dataloader keyword arguments.
    """
    dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup

    # Baseline: datamodule built on the single dataset path.
    dm_single = GraphNeTDataModule(
        dataset_reference=dataset_ref,
        dataset_args=deepcopy(dataset_kwargs),
        train_dataloader_kwargs=dataloader_kwargs,
    )

    # Ensemble: the very same path listed twice stands in for two datasets.
    ensemble_dataset_kwargs = deepcopy(dataset_kwargs)
    ensemble_dataset_kwargs["path"] = [ensemble_dataset_kwargs["path"]] * 2
    dm_ensemble = GraphNeTDataModule(
        dataset_reference=dataset_ref,
        dataset_args=ensemble_dataset_kwargs,
        train_dataloader_kwargs=dataloader_kwargs,
    )

    # The ensemble must produce strictly more training batches ...
    n_single_train = len(dm_single.train_dataloader)
    n_ensemble_train = len(dm_ensemble.train_dataloader)
    assert n_single_train < n_ensemble_train

    # ... and strictly more validation batches than the single dataset.
    n_single_val = len(dm_single.val_dataloader)
    n_ensemble_val = len(dm_ensemble.val_dataloader)
    assert n_single_val < n_ensemble_val


# `indirect=True` routes each parameter through the `dataset_ref` fixture,
# matching how the other ensemble test in this file is parametrized.
@pytest.mark.parametrize(
    "dataset_ref", [SQLiteDataset, ParquetDataset], indirect=True
)
def test_ensemble_dataset_with_selections(
    dataset_setup: Tuple[Any, Dict[str, Any], Dict[str, int]]
) -> None:
    """Test ensemble dataset functionality with selections.

    Three scenarios are exercised on a dataset path that is listed twice:

    1. Two datasets but a single selection: construction must raise.
    2. Two datasets and two selections: the train and validation
       dataloaders together must hold as many events as both selections.
    3. Two datasets, two selections and two test selections: the test
       dataloader must hold as many events as both test selections.

    Args:
        dataset_setup: Tuple holding the dataset reference, the dataset
            keyword arguments and the dataloader keyword arguments.
    """
    # Extract all events
    dataset_ref, dataset_kwargs, dataloader_kwargs = dataset_setup
    file_path = dataset_kwargs["path"]
    selection = extract_all_events_ids(
        file_path=file_path, dataset_kwargs=dataset_kwargs
    )

    # Copy dataset path twice; mimic ensemble dataset behavior
    ensemble_dataset_kwargs = deepcopy(dataset_kwargs)
    dataset_path = ensemble_dataset_kwargs["path"]
    ensemble_dataset_kwargs["path"] = [dataset_path, dataset_path]

    # Pass two datasets but only one selection; should fail.
    # NOTE(review): the concrete exception type raised by the datamodule is
    # not visible here, so the broad `Exception` match is kept as-is.
    with pytest.raises(Exception):
        _ = GraphNeTDataModule(
            dataset_reference=dataset_ref,
            dataset_args=ensemble_dataset_kwargs,
            train_dataloader_kwargs=dataloader_kwargs,
            selection=selection,
        )

    # Pass two datasets and two selections; should work:
    selection_1 = selection[0:20]
    selection_2 = selection[0:10]
    n_selected_events = len(selection_1) + len(selection_2)
    dm = GraphNeTDataModule(
        dataset_reference=dataset_ref,
        dataset_args=ensemble_dataset_kwargs,
        train_dataloader_kwargs=dataloader_kwargs,
        selection=[selection_1, selection_2],
    )
    n_train_events = len(dm.train_dataloader.dataset)  # type: ignore
    n_val_events = len(dm.val_dataloader.dataset)  # type: ignore
    n_events_in_dataloaders = n_train_events + n_val_events

    # Check that the number of events in train/val match
    assert n_events_in_dataloaders == n_selected_events

    # Pass two datasets, two selections and two test selections; should work
    dm2 = GraphNeTDataModule(
        dataset_reference=dataset_ref,
        dataset_args=ensemble_dataset_kwargs,
        train_dataloader_kwargs=dataloader_kwargs,
        selection=[selection, selection],
        test_selection=[selection_1, selection_2],
    )

    # Check that the number of events in test dataloaders are correct.
    n_events_in_test_dataloaders = len(dm2.test_dataloader.dataset)  # type: ignore
    assert n_events_in_test_dataloaders == n_selected_events

0 comments on commit 57f49fe

Please sign in to comment.