From 10a79b383236e28cf69ac9e59ff0779a3dfa4bd5 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Wed, 3 Apr 2024 20:31:18 +0200
Subject: [PATCH 1/6] set default values for collate_fn in DataModule

---
 src/graphnet/data/datamodule.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py
index 2e432bcff..42c7bb7c5 100644
--- a/src/graphnet/data/datamodule.py
+++ b/src/graphnet/data/datamodule.py
@@ -13,6 +13,7 @@
     ParquetDataset,
 )
 from graphnet.utilities.logging import Logger
+from graphnet.training.utils import collate_fn
 
 
 class GraphNeTDataModule(pl.LightningDataModule, Logger):
@@ -65,6 +66,7 @@ def __init__(
         self._validation_dataloader_kwargs = validation_dataloader_kwargs or {}
         self._test_dataloader_kwargs = test_dataloader_kwargs or {}
 
+        self._resolve_dataloader_kwargs()
         # If multiple dataset paths are given, we should use EnsembleDataset
         self._use_ensemble_dataset = isinstance(
             self._dataset_args["path"], list
@@ -72,6 +74,14 @@ def __init__(
 
         self.setup("fit")
 
+    def _resolve_dataloader_kwargs(self) -> None:
+        if "collate_fn" not in self._train_dataloader_kwargs:
+            self._train_dataloader_kwargs["collate_fn"] = collate_fn
+        if "collate_fn" not in self._validation_dataloader_kwargs:
+            self._validation_dataloader_kwargs["collate_fn"] = collate_fn
+        if "collate_fn" not in self._test_dataloader_kwargs:
+            self._test_dataloader_kwargs["collate_fn"] = collate_fn
+
     def prepare_data(self) -> None:
         """Prepare the dataset for training."""
         # Download method for curated datasets. Method for download is

From 3e5cafabd04011ffb2ebf9fe98509b55c7ea99f9 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Wed, 3 Apr 2024 20:43:08 +0200
Subject: [PATCH 2/6] make `GraphNeTDataModule` use `DataLoader` from graphnet

---
 src/graphnet/data/datamodule.py | 12 +-----------
 1 file changed, 1 insertion(+), 11 deletions(-)

diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py
index 42c7bb7c5..bacf95bb9 100644
--- a/src/graphnet/data/datamodule.py
+++ b/src/graphnet/data/datamodule.py
@@ -1,7 +1,6 @@
 """Base `Dataloader` class(es) used in `graphnet`."""
 from typing import Dict, Any, Optional, List, Tuple, Union
 import pytorch_lightning as pl
-from torch.utils.data import DataLoader
 from copy import deepcopy
 from sklearn.model_selection import train_test_split
 import pandas as pd
@@ -13,7 +12,7 @@
     ParquetDataset,
 )
 from graphnet.utilities.logging import Logger
-from graphnet.training.utils import collate_fn
+from graphnet.data.dataloader import DataLoader
 
 
 class GraphNeTDataModule(pl.LightningDataModule, Logger):
@@ -66,7 +65,6 @@ def __init__(
         self._validation_dataloader_kwargs = validation_dataloader_kwargs or {}
         self._test_dataloader_kwargs = test_dataloader_kwargs or {}
 
-        self._resolve_dataloader_kwargs()
         # If multiple dataset paths are given, we should use EnsembleDataset
         self._use_ensemble_dataset = isinstance(
             self._dataset_args["path"], list
@@ -74,14 +72,6 @@ def __init__(
 
         self.setup("fit")
 
-    def _resolve_dataloader_kwargs(self) -> None:
-        if "collate_fn" not in self._train_dataloader_kwargs:
-            self._train_dataloader_kwargs["collate_fn"] = collate_fn
-        if "collate_fn" not in self._validation_dataloader_kwargs:
-            self._validation_dataloader_kwargs["collate_fn"] = collate_fn
-        if "collate_fn" not in self._test_dataloader_kwargs:
-            self._test_dataloader_kwargs["collate_fn"] = collate_fn
-
     def prepare_data(self) -> None:
         """Prepare the dataset for training."""
         # Download method for curated datasets. Method for download is
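Aside: patches 1 and 2 arrive at the same behaviour by different routes. Patch 1 patches a default `collate_fn` into each of the three dataloader kwargs dicts; patch 2 deletes that again and instead imports `DataLoader` from `graphnet.data.dataloader`, which carries the default itself. A minimal sketch of that pattern, not the actual graphnet source (`collate_graphs` is a stand-in for graphnet's `collate_fn`):

    from typing import Any, Callable, List, Sequence

    from torch.utils.data import DataLoader as TorchDataLoader


    def collate_graphs(batch: Sequence[Any]) -> List[Any]:
        """Stand-in for graphnet's `collate_fn`, which batches graphs."""
        return list(batch)


    class DataLoader(TorchDataLoader):
        """`DataLoader` with a library-wide default `collate_fn`."""

        def __init__(
            self,
            dataset: Any,
            collate_fn: Callable = collate_graphs,
            **kwargs: Any,
        ) -> None:
            # Callers may still override `collate_fn`; they just no
            # longer have to pass it.
            super().__init__(dataset, collate_fn=collate_fn, **kwargs)

Keeping the default in one class is why patch 2 can drop `_resolve_dataloader_kwargs`: the defaulting logic lives with the loader instead of being re-applied to three kwargs dicts.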
From f429afc511ecc55aeb25acb3bc4e357345352598 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Wed, 3 Apr 2024 21:07:24 +0200
Subject: [PATCH 3/6] force multiprocessing context to "spawn" in `ParquetDataset`

---
 src/graphnet/data/dataset/parquet/parquet_dataset.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/graphnet/data/dataset/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py
index 3c3815a20..77fe882f9 100644
--- a/src/graphnet/data/dataset/parquet/parquet_dataset.py
+++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -31,6 +31,13 @@
 from graphnet.exceptions.exceptions import ColumnMissingException
 
 
+# Force spawn-method
+try:
+    torch.multiprocessing.set_start_method("spawn")
+except RuntimeError:
+    pass
+
+
 class ParquetDataset(Dataset):
     """Dataset class for Parquet-files converted with `ParquetWriter`."""
 

From f09f5dd7a86acf794ef3d6da3b63252d08a5a039 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Wed, 3 Apr 2024 21:42:19 +0200
Subject: [PATCH 4/6] solve hang from polars multiprocessing

---
 src/graphnet/data/datamodule.py                      | 1 -
 src/graphnet/data/dataset/parquet/parquet_dataset.py | 3 ---
 2 files changed, 4 deletions(-)

diff --git a/src/graphnet/data/datamodule.py b/src/graphnet/data/datamodule.py
index bacf95bb9..f012c4abe 100644
--- a/src/graphnet/data/datamodule.py
+++ b/src/graphnet/data/datamodule.py
@@ -316,7 +316,6 @@ def _split_selection(
         Returns:
             Training selection, Validation selection.
         """
-        print(selection)
         assert isinstance(selection, (int, list))
         if isinstance(selection, int):
             flat_selection = [selection]
diff --git a/src/graphnet/data/dataset/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py
index 77fe882f9..9e1e98db2 100644
--- a/src/graphnet/data/dataset/parquet/parquet_dataset.py
+++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -30,7 +30,6 @@
 from graphnet.data.dataset import Dataset
 from graphnet.exceptions.exceptions import ColumnMissingException
 
-
 # Force spawn-method
 try:
     torch.multiprocessing.set_start_method("spawn")
@@ -202,7 +201,6 @@ def _calculate_sizes(self) -> List[int]:
         """Calculate the number of events in each batch."""
         sizes = []
         for batch_id in self._indices:
-            print(batch_id)
             path = os.path.join(
                 self._path,
                 self._truth_table,
@@ -304,7 +302,6 @@ def _load_table(self, table_name: str, file_idx: int) -> None:
         file_path = os.path.join(
             self._path, table_name, f"{table_name}_{file_idx}.parquet"
         )
-        print(file_idx)
         df = pol.read_parquet(file_path).sort(self._index_column)
         if (table_name in self._pulsemaps) or (
             table_name == self._node_truth_table
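Aside: the bug these patches chase is real. polars keeps an internal thread pool, and on Linux the default `fork` start method copies a running process with all threads excluded, so a worker can inherit locks that no surviving thread will ever release and hang. Patch 3 forces `spawn` globally at import time, which is why it needs the `try/except`: `set_start_method` raises `RuntimeError` if a context was already set. A less invasive alternative, shown here as a self-contained sketch with a toy dataset rather than graphnet code, is to request the context per loader:

    from torch.utils.data import DataLoader, Dataset


    class ToyDataset(Dataset):
        """Eight integers; a stand-in for a `ParquetDataset` instance."""

        def __len__(self) -> int:
            return 8

        def __getitem__(self, idx: int) -> int:
            return idx


    loader = DataLoader(
        ToyDataset(),
        batch_size=4,
        num_workers=2,
        # Workers start as fresh interpreters instead of fork()ed
        # copies, so no inherited thread or lock state can deadlock them.
        multiprocessing_context="spawn",
    )

This leaves the global start method untouched for everything else running in the process.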
From 300e6c9675bde11840e6973f36865dce9c4d28ce Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Thu, 4 Apr 2024 08:43:28 +0200
Subject: [PATCH 5/6] move spawn method inside ParquetDataset

---
 src/graphnet/data/dataset/parquet/parquet_dataset.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/graphnet/data/dataset/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py
index 9e1e98db2..2e1024a08 100644
--- a/src/graphnet/data/dataset/parquet/parquet_dataset.py
+++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -30,12 +30,6 @@
 from graphnet.data.dataset import Dataset
 from graphnet.exceptions.exceptions import ColumnMissingException
 
-# Force spawn-method
-try:
-    torch.multiprocessing.set_start_method("spawn")
-except RuntimeError:
-    pass
-
 
 class ParquetDataset(Dataset):
     """Dataset class for Parquet-files converted with `ParquetWriter`."""
@@ -290,6 +284,11 @@ def _query_table(
             )
             data = df.select(columns)
             if isinstance(data[columns[0]][0], Series):
+                # Force spawn-method
+                try:
+                    torch.multiprocessing.set_start_method("spawn")
+                except RuntimeError:
+                    pass
                 data = data.explode(columns)
             array = data.to_numpy()
         else:

From 79b400b554fa28b13f84d7491fe67eb94382f092 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Thu, 4 Apr 2024 09:00:14 +0200
Subject: [PATCH 6/6] remove forced multiprocessing context

- switch `.explode` to pure numpy
---
 src/graphnet/data/dataset/parquet/parquet_dataset.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/src/graphnet/data/dataset/parquet/parquet_dataset.py b/src/graphnet/data/dataset/parquet/parquet_dataset.py
index 2e1024a08..0feec8da9 100644
--- a/src/graphnet/data/dataset/parquet/parquet_dataset.py
+++ b/src/graphnet/data/dataset/parquet/parquet_dataset.py
@@ -284,13 +284,10 @@ def _query_table(
             )
             data = df.select(columns)
             if isinstance(data[columns[0]][0], Series):
-                # Force spawn-method
-                try:
-                    torch.multiprocessing.set_start_method("spawn")
-                except RuntimeError:
-                    pass
-                data = data.explode(columns)
-            array = data.to_numpy()
+                x = [data[col][0].to_numpy().reshape(-1, 1) for col in columns]
+                array = np.concatenate(x, axis=1)
+            else:
+                array = data.to_numpy()
         else:
             array = np.array()
         return array
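Aside: the patch 6 rewrite sidesteps the multiprocessing question entirely by keeping polars' `.explode` out of the hot path. For the single-row frames `_query_table` selects, each cell of a pulse-level column is a list-typed `Series`, so reading each cell with `.to_numpy()` and stacking the results column-wise gives the same `(n_pulses, n_columns)` array that `.explode(columns)` followed by `.to_numpy()` produced. A toy demonstration (column names are illustrative, not graphnet's schema):

    import numpy as np
    import polars as pol

    # One event; each list-typed cell holds one value per pulse.
    df = pol.DataFrame(
        {"dom_x": [[1.0, 2.0, 3.0]], "dom_y": [[4.0, 5.0, 6.0]]}
    )

    # Patch-6 route: read each list cell as a numpy column and stack.
    x = [df[col][0].to_numpy().reshape(-1, 1) for col in df.columns]
    array = np.concatenate(x, axis=1)  # shape (3, 2): one row per pulse

    # Same result as the old route through polars:
    assert (array == df.explode(df.columns).to_numpy()).all()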