From 73e3bf67fec89a06d80916c9621145dc4bb77738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 10:29:57 +0200 Subject: [PATCH 01/29] add trident --- src/graphnet/datasets/__init__.py | 1 + src/graphnet/datasets/trident.py | 67 +++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) create mode 100644 src/graphnet/datasets/trident.py diff --git a/src/graphnet/datasets/__init__.py b/src/graphnet/datasets/__init__.py index fc2e81be2..25796f769 100644 --- a/src/graphnet/datasets/__init__.py +++ b/src/graphnet/datasets/__init__.py @@ -1,2 +1,3 @@ """Contains pre-converted datasets ready for training.""" from .test_dataset import TestDataset +from .trident import TRIDENTSmall diff --git a/src/graphnet/datasets/trident.py b/src/graphnet/datasets/trident.py new file mode 100644 index 000000000..438a1f9c9 --- /dev/null +++ b/src/graphnet/datasets/trident.py @@ -0,0 +1,67 @@ +"""A CuratedDataset for unit tests.""" +from typing import Dict, Any, List, Tuple, Union +import os + +from graphnet.data import ERDAHostedDataset +from graphnet.data.constants import FEATURES + + +class TRIDENTSmall(ERDAHostedDataset): + """A Dataset for Prometheus simulation of TRIDENT. + + Small version with ~ 1 mill track events. + """ + + # Static Member Variables: + _pulsemaps = ["photons"] + _truth_table = "mc_truth" + _event_truth = [ + "interaction", + "initial_state_energy", + "initial_state_type", + "initial_state_zenith", + "initial_state_azimuth", + "initial_state_x", + "initial_state_y", + "initial_state_z", + ] + _pulse_truth = None + _features = FEATURES.PROMETHEUS + _experiment = "TRIDENT Prometheus Simulation" + _creator = "Rasmus F. Ørsøe" + _comments = ( + "Contains ~1 million track events." + " Simulation produced by Stephan Meighen-Berger, " + "U. Melbourne." + ) + _available_backends = ["sqlite", "parquet"] + _file_hashes = {"sqlite": "F2R8qb8JW7", "parquet": "BRRgyslRno"} + _citation = None + + def _prepare_args( + self, backend: str, features: List[str], truth: List[str] + ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]: + """Prepare arguments for dataset. + + Args: + backend: backend of dataset. Either "parquet" or "sqlite" + features: List of features from user to use as input. + truth: List of event-level truth form user. + + Returns: Dataset arguments and selections + """ + if backend == "sqlite": + dataset_path = os.path.join(self.dataset_dir, "merged.db") + elif backend == "parquet": + dataset_path = self.dataset_dir + + dataset_args = { + "truth_table": self._truth_table, + "pulsemaps": self._pulsemaps, + "path": dataset_path, + "graph_definition": self._graph_definition, + "features": features, + "truth": truth, + } + + return dataset_args, None, None From d052d88b232da0cf39dd5104cd4ac7d7f98eaff1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:13:49 +0200 Subject: [PATCH 02/29] add selections --- src/graphnet/datasets/trident.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/graphnet/datasets/trident.py b/src/graphnet/datasets/trident.py index 438a1f9c9..66e65c0ac 100644 --- a/src/graphnet/datasets/trident.py +++ b/src/graphnet/datasets/trident.py @@ -4,6 +4,8 @@ from graphnet.data import ERDAHostedDataset from graphnet.data.constants import FEATURES +from graphnet.data.utilities import query_database +from sklearn.model_selection import train_test_split class TRIDENTSmall(ERDAHostedDataset): @@ -34,8 +36,8 @@ class TRIDENTSmall(ERDAHostedDataset): " Simulation produced by Stephan Meighen-Berger, " "U. Melbourne." ) - _available_backends = ["sqlite", "parquet"] - _file_hashes = {"sqlite": "F2R8qb8JW7", "parquet": "BRRgyslRno"} + _available_backends = ["sqlite"] + _file_hashes = {"sqlite": "F2R8qb8JW7"} _citation = None def _prepare_args( @@ -50,11 +52,18 @@ def _prepare_args( Returns: Dataset arguments and selections """ - if backend == "sqlite": - dataset_path = os.path.join(self.dataset_dir, "merged.db") - elif backend == "parquet": - dataset_path = self.dataset_dir + dataset_path = os.path.join(self.dataset_dir, "merged.db") + event_nos = query_database( + database=dataset_path, query="SELECT event_no FROM mc_truth" + ) + + train_val, test = train_test_split( + event_nos["event_no"].tolist(), + test_size=0.10, + random_state=42, + shuffle=True, + ) dataset_args = { "truth_table": self._truth_table, "pulsemaps": self._pulsemaps, @@ -64,4 +73,4 @@ def _prepare_args( "truth": truth, } - return dataset_args, None, None + return dataset_args, train_val, test From f85b7142cd883cccabd75404a2dde7f25196fbb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:26:28 +0200 Subject: [PATCH 03/29] add `Direction` label to dataset --- src/graphnet/datasets/trident.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/graphnet/datasets/trident.py b/src/graphnet/datasets/trident.py index 66e65c0ac..8b912e15a 100644 --- a/src/graphnet/datasets/trident.py +++ b/src/graphnet/datasets/trident.py @@ -1,11 +1,12 @@ """A CuratedDataset for unit tests.""" from typing import Dict, Any, List, Tuple, Union import os +from sklearn.model_selection import train_test_split +from graphnet.training.labels import Direction from graphnet.data import ERDAHostedDataset from graphnet.data.constants import FEATURES from graphnet.data.utilities import query_database -from sklearn.model_selection import train_test_split class TRIDENTSmall(ERDAHostedDataset): @@ -71,6 +72,12 @@ def _prepare_args( "graph_definition": self._graph_definition, "features": features, "truth": truth, + "labels": [ + Direction( + azimuth_key="initial_state_azimuth", + zenith_key="initial_state_zenith", + ) + ], } return dataset_args, train_val, test From 1acd7eb13196322c2a1997c96761f70db8d130a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:28:14 +0200 Subject: [PATCH 04/29] fix direction label --- src/graphnet/datasets/trident.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/graphnet/datasets/trident.py b/src/graphnet/datasets/trident.py index 8b912e15a..04b9dd14b 100644 --- a/src/graphnet/datasets/trident.py +++ b/src/graphnet/datasets/trident.py @@ -72,12 +72,12 @@ def _prepare_args( "graph_definition": self._graph_definition, "features": features, "truth": truth, - "labels": [ - Direction( + "labels": { + "direction": Direction( azimuth_key="initial_state_azimuth", zenith_key="initial_state_zenith", ) - ], + }, } return dataset_args, train_val, test From 736f564aa32e4906fba02f4f19e9b9cc5f58c072 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:34:47 +0200 Subject: [PATCH 05/29] add Track Label --- src/graphnet/datasets/trident.py | 7 +++++-- src/graphnet/training/labels.py | 36 ++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/src/graphnet/datasets/trident.py b/src/graphnet/datasets/trident.py index 04b9dd14b..0076bef95 100644 --- a/src/graphnet/datasets/trident.py +++ b/src/graphnet/datasets/trident.py @@ -3,7 +3,7 @@ import os from sklearn.model_selection import train_test_split -from graphnet.training.labels import Direction +from graphnet.training.labels import Direction, Track from graphnet.data import ERDAHostedDataset from graphnet.data.constants import FEATURES from graphnet.data.utilities import query_database @@ -76,7 +76,10 @@ def _prepare_args( "direction": Direction( azimuth_key="initial_state_azimuth", zenith_key="initial_state_zenith", - ) + ), + "track": Track( + pid_key="initial_state_type", interaction_key="interaction" + ), }, } diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 1a17ab16a..bacaa16ee 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -68,3 +68,39 @@ def __call__(self, graph: Data) -> torch.tensor: ).reshape(-1, 1) z = torch.cos(graph[self._zenith_key]).reshape(-1, 1) return torch.cat((x, y, z), dim=1) + + +class Track(Label): + """Class for producing NuMuCC label. + + Label is set to `1` if the event is a NuMu CC event, else `0`. + """ + + def __init__( + self, + key: str = "track", + pid_key: str = "pid", + interaction_key: str = "interaction_type", + ): + """Construct `Track` label. + + Args: + key: The name of the field in `Data` where the label will be + stored. That is, `graph[key] = label`. + pid_key: The name of the pre-existing key in `graph` that will + be used to access the pdg encoding, used when calculating + the direction. + interaction_key: The name of the pre-existing key in `graph` that will + be used to access the interaction type (1 denoting CC), + used when calculating the direction. + """ + self._pid_key = pid_key + self._int_key = interaction_key + + # Base class constructor + super().__init__(key=key) + + def __call__(self, graph: Data) -> torch.tensor: + """Compute label for `graph`.""" + label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) + return torch.tensor(label, dtype=torch.int64) From 9464f7a04fcf2f8ce79291bba509160ee9e8fa10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:44:03 +0200 Subject: [PATCH 06/29] refactor prometheus dataset --- src/graphnet/datasets/__init__.py | 2 +- src/graphnet/datasets/prometheus_datasets.py | 91 ++++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 src/graphnet/datasets/prometheus_datasets.py diff --git a/src/graphnet/datasets/__init__.py b/src/graphnet/datasets/__init__.py index 25796f769..199df5310 100644 --- a/src/graphnet/datasets/__init__.py +++ b/src/graphnet/datasets/__init__.py @@ -1,3 +1,3 @@ """Contains pre-converted datasets ready for training.""" from .test_dataset import TestDataset -from .trident import TRIDENTSmall +from .prometheus_datasets import TRIDENTSmall diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py new file mode 100644 index 000000000..b5a01ae4a --- /dev/null +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -0,0 +1,91 @@ +"""A CuratedDataset for unit tests.""" +from typing import Dict, Any, List, Tuple, Union +import os +from sklearn.model_selection import train_test_split + +from graphnet.training.labels import Direction, Track +from graphnet.data import ERDAHostedDataset +from graphnet.data.constants import FEATURES +from graphnet.data.utilities import query_database + + +class PublicPrometheusDataset(ERDAHostedDataset): + """A generic class for public Prometheus Datasets hosted using ERDA.""" + + # Static Member Variables: + _pulsemaps = ["photons"] + _truth_table = "mc_truth" + _event_truth = [ + "interaction", + "initial_state_energy", + "initial_state_type", + "initial_state_zenith", + "initial_state_azimuth", + "initial_state_x", + "initial_state_y", + "initial_state_z", + ] + _pulse_truth = None + _features = FEATURES.PROMETHEUS + + def _prepare_args( + self, backend: str, features: List[str], truth: List[str] + ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]: + """Prepare arguments for dataset. + + Args: + backend: backend of dataset. Either "parquet" or "sqlite" + features: List of features from user to use as input. + truth: List of event-level truth form user. + + Returns: Dataset arguments and selections + """ + dataset_path = os.path.join(self.dataset_dir, "merged.db") + + event_nos = query_database( + database=dataset_path, query="SELECT event_no FROM mc_truth" + ) + + train_val, test = train_test_split( + event_nos["event_no"].tolist(), + test_size=0.10, + random_state=42, + shuffle=True, + ) + dataset_args = { + "truth_table": self._truth_table, + "pulsemaps": self._pulsemaps, + "path": dataset_path, + "graph_definition": self._graph_definition, + "features": features, + "truth": truth, + "labels": { + "direction": Direction( + azimuth_key="initial_state_azimuth", + zenith_key="initial_state_zenith", + ), + "track": Track( + pid_key="initial_state_type", interaction_key="interaction" + ), + }, + } + + return dataset_args, train_val, test + + +class TRIDENTSmall(PublicPrometheusDataset): + """Public Dataset for Prometheus simulation of a TRIDENT geometry. + + Contains ~ 1 million track events. + """ + + _experiment = "TRIDENT Prometheus Simulation" + _creator = "Rasmus F. Ørsøe" + _comments = ( + "Contains ~1 million track events." + " Simulation produced by Stephan Meighen-Berger, " + "U. Melbourne." + ) + _available_backends = ["sqlite"] + _file_hashes = {"sqlite": "F2R8qb8JW7"} + _citation = None From f51307269a434f952cf058d3f0234627fa65f403 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:45:37 +0200 Subject: [PATCH 07/29] update track label --- src/graphnet/training/labels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index bacaa16ee..4a0debff9 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -103,4 +103,4 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) - return torch.tensor(label, dtype=torch.int64) + return label # torch.tensor(label, dtype=torch.int64) From 2a58b58b3590025d2966a78eb0c9d5417270a485 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:47:04 +0200 Subject: [PATCH 08/29] cast bool to int track label --- src/graphnet/training/labels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 4a0debff9..478bbd038 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -103,4 +103,4 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) - return label # torch.tensor(label, dtype=torch.int64) + return label.type(torch.int) # torch.tensor(label, dtype=torch.int64) From dd3b1ffb612728edfd756e3857294d568c758c9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 12:50:41 +0200 Subject: [PATCH 09/29] delete redundant file --- src/graphnet/datasets/trident.py | 86 -------------------------------- 1 file changed, 86 deletions(-) delete mode 100644 src/graphnet/datasets/trident.py diff --git a/src/graphnet/datasets/trident.py b/src/graphnet/datasets/trident.py deleted file mode 100644 index 0076bef95..000000000 --- a/src/graphnet/datasets/trident.py +++ /dev/null @@ -1,86 +0,0 @@ -"""A CuratedDataset for unit tests.""" -from typing import Dict, Any, List, Tuple, Union -import os -from sklearn.model_selection import train_test_split - -from graphnet.training.labels import Direction, Track -from graphnet.data import ERDAHostedDataset -from graphnet.data.constants import FEATURES -from graphnet.data.utilities import query_database - - -class TRIDENTSmall(ERDAHostedDataset): - """A Dataset for Prometheus simulation of TRIDENT. - - Small version with ~ 1 mill track events. - """ - - # Static Member Variables: - _pulsemaps = ["photons"] - _truth_table = "mc_truth" - _event_truth = [ - "interaction", - "initial_state_energy", - "initial_state_type", - "initial_state_zenith", - "initial_state_azimuth", - "initial_state_x", - "initial_state_y", - "initial_state_z", - ] - _pulse_truth = None - _features = FEATURES.PROMETHEUS - _experiment = "TRIDENT Prometheus Simulation" - _creator = "Rasmus F. Ørsøe" - _comments = ( - "Contains ~1 million track events." - " Simulation produced by Stephan Meighen-Berger, " - "U. Melbourne." - ) - _available_backends = ["sqlite"] - _file_hashes = {"sqlite": "F2R8qb8JW7"} - _citation = None - - def _prepare_args( - self, backend: str, features: List[str], truth: List[str] - ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]: - """Prepare arguments for dataset. - - Args: - backend: backend of dataset. Either "parquet" or "sqlite" - features: List of features from user to use as input. - truth: List of event-level truth form user. - - Returns: Dataset arguments and selections - """ - dataset_path = os.path.join(self.dataset_dir, "merged.db") - - event_nos = query_database( - database=dataset_path, query="SELECT event_no FROM mc_truth" - ) - - train_val, test = train_test_split( - event_nos["event_no"].tolist(), - test_size=0.10, - random_state=42, - shuffle=True, - ) - dataset_args = { - "truth_table": self._truth_table, - "pulsemaps": self._pulsemaps, - "path": dataset_path, - "graph_definition": self._graph_definition, - "features": features, - "truth": truth, - "labels": { - "direction": Direction( - azimuth_key="initial_state_azimuth", - zenith_key="initial_state_zenith", - ), - "track": Track( - pid_key="initial_state_type", interaction_key="interaction" - ), - }, - } - - return dataset_args, train_val, test From 8e7b16544b68d494f6a221f16a2b5e554c195e6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Tue, 30 Apr 2024 13:20:49 +0200 Subject: [PATCH 10/29] add PONESmall Dataset --- src/graphnet/datasets/prometheus_datasets.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index b5a01ae4a..b1f64ea98 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -76,7 +76,7 @@ def _prepare_args( class TRIDENTSmall(PublicPrometheusDataset): """Public Dataset for Prometheus simulation of a TRIDENT geometry. - Contains ~ 1 million track events. + Contains ~ 1 million track events between 10 GeV - 10 TeV. """ _experiment = "TRIDENT Prometheus Simulation" @@ -89,3 +89,21 @@ class TRIDENTSmall(PublicPrometheusDataset): _available_backends = ["sqlite"] _file_hashes = {"sqlite": "F2R8qb8JW7"} _citation = None + + +class PONESmall(PublicPrometheusDataset): + """Public Dataset for Prometheus simulation of a P-ONE geometry. + + Contains ~ 1 million track events between 10 GeV - 10 TeV. + """ + + _experiment = "P-ONE Prometheus Simulation" + _creator = "Rasmus F. Ørsøe" + _comments = ( + "Contains ~1 million track events." + " Simulation produced by Stephan Meighen-Berger, " + "U. Melbourne." + ) + _available_backends = ["sqlite"] + _file_hashes = {"sqlite": "e9ZSVMiykD"} + _citation = None From e3decea2109b07c578e64b6bba19dfc3bf51db27 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 10:03:39 +0200 Subject: [PATCH 11/29] add baikal-gvd --- src/graphnet/datasets/prometheus_datasets.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index b1f64ea98..e9d269f2a 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -107,3 +107,20 @@ class PONESmall(PublicPrometheusDataset): _available_backends = ["sqlite"] _file_hashes = {"sqlite": "e9ZSVMiykD"} _citation = None + +class BaikailGVDSmall(PublicPrometheusDataset): + """Public Dataset for Prometheus simulation of a Baikal-GVD geometry. + + Contains ~ 1 million track events between 10 GeV - 10 TeV. + """ + + _experiment = "Baikal-GVD Prometheus Simulation" + _creator = "Rasmus F. Ørsøe" + _comments = ( + "Contains ~1 million track events." + " Simulation produced by Stephan Meighen-Berger, " + "U. Melbourne." + ) + _available_backends = ["sqlite"] + _file_hashes = {"sqlite": "ebLJHjPDqy"} + _citation = None From b348c1ee8dc9353700073604a584ccb2f60d04e6 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 10:10:00 +0200 Subject: [PATCH 12/29] remove comment in Track label --- src/graphnet/training/labels.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 478bbd038..14d84236d 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -103,4 +103,4 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) - return label.type(torch.int) # torch.tensor(label, dtype=torch.int64) + return label.type(torch.int) \ No newline at end of file From b42871713160d403f089e4061b2a8ea351fc6f6c Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 10:13:48 +0200 Subject: [PATCH 13/29] codeclimate --- src/graphnet/training/labels.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 14d84236d..4d33bdacd 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -90,9 +90,9 @@ def __init__( pid_key: The name of the pre-existing key in `graph` that will be used to access the pdg encoding, used when calculating the direction. - interaction_key: The name of the pre-existing key in `graph` that will - be used to access the interaction type (1 denoting CC), - used when calculating the direction. + interaction_key: The name of the pre-existing key in `graph` that + will be used to access the interaction type (1 denoting CC), + used when calculating the direction. """ self._pid_key = pid_key self._int_key = interaction_key @@ -103,4 +103,5 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) - return label.type(torch.int) \ No newline at end of file + return label.type(torch.int) + \ No newline at end of file From b1232c1401f67bf4d75929f749949759b8c866ae Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 10:14:18 +0200 Subject: [PATCH 14/29] codeclimate --- src/graphnet/datasets/prometheus_datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index e9d269f2a..910362be3 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -108,6 +108,7 @@ class PONESmall(PublicPrometheusDataset): _file_hashes = {"sqlite": "e9ZSVMiykD"} _citation = None + class BaikailGVDSmall(PublicPrometheusDataset): """Public Dataset for Prometheus simulation of a Baikal-GVD geometry. From 1ddcc0a700090189f1d5754308cf8b3efa5a07c3 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 10:15:18 +0200 Subject: [PATCH 15/29] update module docstring --- src/graphnet/datasets/prometheus_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 910362be3..6ddd2f09a 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -1,4 +1,4 @@ -"""A CuratedDataset for unit tests.""" +"""Public datasets from Prometheus Simulation.""" from typing import Dict, Any, List, Tuple, Union import os from sklearn.model_selection import train_test_split From fd9eccc9fb747ee391d258556ba0026b23272877 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 10:30:01 +0200 Subject: [PATCH 16/29] add pone geometry table --- .../prometheus/pone_triangle.parquet | Bin 0 -> 7481 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 data/geometry_tables/prometheus/pone_triangle.parquet diff --git a/data/geometry_tables/prometheus/pone_triangle.parquet b/data/geometry_tables/prometheus/pone_triangle.parquet new file mode 100644 index 0000000000000000000000000000000000000000..88018023cdf7146a65aa02df8ae06c1b921fca73 GIT binary patch literal 7481 zcmeHMU2I!d9X~f;P1Xq`udVebald zU4O*auBweORn;DNKvRXL3Q<54j4B(52S5S&1!;;Vc$r`i6HK6rNgy#zuxI|~+-v(f zw%ZLgdEzQR-{bTDpa1tc|C3~yU2)KTw4I(@pa*Dl2q7~<{Zt=U>)&TXoCoO-jp%Xu z_#~~;qrf*XVII^7oTo9WnUJl+Ebxukd*D;0dThu>*?O3ds0n)im)Y!bXcbPs}$QFtK)-!1fs1w0L2jb8X*>1WT=f4)%D z-U}XJG|}TmkLncjV4W9x)IJNs57;!1sAceox!vE{p9=BkUCQ(x#r(UM`9&{y)Gn@z zl8zqLdg!0`EYLRM)AuQaDC*;q=f&DLe3g+0|G827&8<(yw|}bEj%{DFy_G1{sv{3v zuN}?T-oCeW>yy{znq+(H-qt|6_6)pk<*(M>1-5%zmujQH_NV3ZHKe^g?`ZQLOa(rV zzRq)BljBiG9#GLi&-iwKrq8oAq$1DAZHos2gOKM%9NMdN(iqx0MK2bxxKS*?-<$8up2kLMJ}{8u1PM2CqAA5mfgNp_dxEfWN+R&`Tg;=N8#~ zqUqA8zC{Yfyr`sC#bkY({E=erj4$&@i{)QCh=%TT3>y$>a5RPm6EC@LFfmXN!2R)YI*~G;B4<;A>c47I% z;^ou}(KqDh|9?*P(&}^jEX)j9^m9-}%xhy! zC#WK+(O*)g8pXU~X5Ks6tWoqxxgCw)+ZMGs(cFKOu`k^@#{7P~;ml`fe;u9KtIo9b zl)(*SJNof)=A$RUjsLeJsN%=mj@~$F2)eNcz4J8lsij^!|7#B#yYE8gfr5}p$qD)p zqW7QdX}v_XPgUBds@harYk_MM$ZMWK+IOm`u~Vx4)+JI^WKzc_#_sOTm2R&tT6=Y& zVD_-f?@5o@)gHj?=9$dw=1Ixy=80h1oG_`(py{vXBM6O6ng?kIIy(N+m#im=?UDhW9S%t}d;fMZQeNy*jLIZefC zv691#){-I$*<=b!RZ4QM(ja8{^^{P=^Kz<~NG6I2V83BaDZma_S-kejW&f#j)|)q* z#3NoAF9PuM9LvtI-ZblH*);1* zlR2-qg7vg7wU?^GRp#I=?HIq_&l5-Qvo z&R+yMt_&Z!g4dRk5w(EV#$`$311oF1>@UVI0bju7-~(kWN4F!imYkQ$gysu6cv%AZ zD-tJ`(;&~mN4>7-^omOeV7aWwJA>Ebk&s=ZJEUa555BBqxbav~Bg#Qp9^@pX#1%m>_=b0oz5qyd_!^L`hMNBdP#}o?CEemWXYMc zYv-1J-U@OU&NFvKKc5SUNU#!%23_R*Q@k!1SNLgx*pigP6~QTO@FG6DIW=@?ro6Du ziuerVnz?hvqrqw-;w}*x>s)D3oGEjP9eUd}fX|M5A&=YQ>_upUvokQsmhcCM`=M9B zL+F0+;HK$gr#7ahVBlp{JjY(aAF#0gkYNfh4e-xSO`U@Q$6jAPJ2i#d?ZrjXd2u6d zwrYJiPwZLJ<@?ystiFD3mNacIS*rDKvK|k;#NOhDE}A ztR~raS^qxgBT9Tb^fkWokbfCBe7a3zM{+LLC#ip-W=B7eP Date: Wed, 1 May 2024 10:34:27 +0200 Subject: [PATCH 17/29] add P-ONE detector class for triangle geometry --- src/graphnet/models/detector/prometheus.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py index 1b93c144d..77cb604e4 100644 --- a/src/graphnet/models/detector/prometheus.py +++ b/src/graphnet/models/detector/prometheus.py @@ -333,6 +333,33 @@ def _sensor_pos_xyz(self, x: torch.tensor) -> torch.tensor: def _t(self, x: torch.tensor) -> torch.tensor: return x / 1.05e04 + +class PONETriangle(Detector): + """`Detector` class for Prometheus PONE Triangle""" + + geometry_table_path = os.path.join( + PROMETHEUS_GEOMETRY_TABLE_DIR, "pone_triangle.parquet" + ) + xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"] + string_id_column = "sensor_string_id" + sensor_id_column = "sensor_id" + + def feature_map(self) -> Dict[str, Callable]: + """Map standardization functions to each dimension.""" + feature_map = { + "sensor_pos_x": self._sensor_pos_xyz, + "sensor_pos_y": self._sensor_pos_xyz, + "sensor_pos_z": self._sensor_pos_xyz, + "t": self._t, + } + return feature_map + + def _sensor_pos_xyz(self, x: torch.tensor) -> torch.tensor: + return x / 100 + + def _t(self, x: torch.tensor) -> torch.tensor: + return x / 1.05e04 + class Prometheus(ORCA150SuperDense): From 40ec9097ca56a3e2d0fe40081bed36cd56ea1711 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 10:38:05 +0200 Subject: [PATCH 18/29] code quality --- src/graphnet/models/detector/prometheus.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/graphnet/models/detector/prometheus.py b/src/graphnet/models/detector/prometheus.py index 77cb604e4..05f68fc41 100644 --- a/src/graphnet/models/detector/prometheus.py +++ b/src/graphnet/models/detector/prometheus.py @@ -333,9 +333,10 @@ def _sensor_pos_xyz(self, x: torch.tensor) -> torch.tensor: def _t(self, x: torch.tensor) -> torch.tensor: return x / 1.05e04 - + + class PONETriangle(Detector): - """`Detector` class for Prometheus PONE Triangle""" + """`Detector` class for Prometheus PONE Triangle.""" geometry_table_path = os.path.join( PROMETHEUS_GEOMETRY_TABLE_DIR, "pone_triangle.parquet" @@ -361,6 +362,5 @@ def _t(self, x: torch.tensor) -> torch.tensor: return x / 1.05e04 - class Prometheus(ORCA150SuperDense): """Reference to ORCA150SuperDense.""" From 1e8f2da0cd1c2a68582677b6d8e15653d041e41e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 10:38:54 +0200 Subject: [PATCH 19/29] typo --- src/graphnet/datasets/prometheus_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 6ddd2f09a..80dda0a82 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -36,7 +36,7 @@ def _prepare_args( Args: backend: backend of dataset. Either "parquet" or "sqlite" features: List of features from user to use as input. - truth: List of event-level truth form user. + truth: List of event-level truth variables from user. Returns: Dataset arguments and selections """ From 2290183a6821b66eb90031967ab11003367b0ef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 10:44:41 +0200 Subject: [PATCH 20/29] code quality --- src/graphnet/training/labels.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 4d33bdacd..3471fbcb4 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -90,7 +90,7 @@ def __init__( pid_key: The name of the pre-existing key in `graph` that will be used to access the pdg encoding, used when calculating the direction. - interaction_key: The name of the pre-existing key in `graph` that + interaction_key: The name of the pre-existing key in `graph` that will be used to access the interaction type (1 denoting CC), used when calculating the direction. """ @@ -104,4 +104,3 @@ def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) return label.type(torch.int) - \ No newline at end of file From 55b16980a4ad9a55d1274edcc39ea4997d673aba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 10:45:55 +0200 Subject: [PATCH 21/29] code quality --- src/graphnet/training/labels.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index 3471fbcb4..11129f915 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -91,8 +91,8 @@ def __init__( be used to access the pdg encoding, used when calculating the direction. interaction_key: The name of the pre-existing key in `graph` that - will be used to access the interaction type (1 denoting CC), - used when calculating the direction. + will be used to access the interaction type (1 denoting CC), + used when calculating the direction. """ self._pid_key = pid_key self._int_key = interaction_key From 535b25cb0eccd03cf755e38ed30ee98b31a50ff8 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 11:03:31 +0200 Subject: [PATCH 22/29] add imports --- src/graphnet/datasets/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/graphnet/datasets/__init__.py b/src/graphnet/datasets/__init__.py index 199df5310..6d4c4bdb1 100644 --- a/src/graphnet/datasets/__init__.py +++ b/src/graphnet/datasets/__init__.py @@ -1,3 +1,3 @@ """Contains pre-converted datasets ready for training.""" from .test_dataset import TestDataset -from .prometheus_datasets import TRIDENTSmall +from .prometheus_datasets import TRIDENTSmall, BaikailGVDSmall, PONESmall From 1c305cb3a505b1fbc058024276b43340a3d8aa11 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 11:04:05 +0200 Subject: [PATCH 23/29] typo --- src/graphnet/datasets/__init__.py | 2 +- src/graphnet/datasets/prometheus_datasets.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/graphnet/datasets/__init__.py b/src/graphnet/datasets/__init__.py index 6d4c4bdb1..74893bb83 100644 --- a/src/graphnet/datasets/__init__.py +++ b/src/graphnet/datasets/__init__.py @@ -1,3 +1,3 @@ """Contains pre-converted datasets ready for training.""" from .test_dataset import TestDataset -from .prometheus_datasets import TRIDENTSmall, BaikailGVDSmall, PONESmall +from .prometheus_datasets import TRIDENTSmall, BaikalGVDSmall, PONESmall diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 80dda0a82..0f8d37620 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -109,7 +109,7 @@ class PONESmall(PublicPrometheusDataset): _citation = None -class BaikailGVDSmall(PublicPrometheusDataset): +class BaikalGVDSmall(PublicPrometheusDataset): """Public Dataset for Prometheus simulation of a Baikal-GVD geometry. Contains ~ 1 million track events between 10 GeV - 10 TeV. From 82ee5d1e988a74a1b4b0bbef7afbab8b4fee9560 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 15:55:48 +0200 Subject: [PATCH 24/29] update sharelinks --- src/graphnet/datasets/prometheus_datasets.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 0f8d37620..8355a96b7 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -43,9 +43,10 @@ def _prepare_args( dataset_path = os.path.join(self.dataset_dir, "merged.db") event_nos = query_database( - database=dataset_path, query="SELECT event_no FROM mc_truth" + database=dataset_path, + query="SELECT event_no FROM mc_truth", ) - + event_nos = event_nos.loc[event_nos["n_photons"] < 1000] train_val, test = train_test_split( event_nos["event_no"].tolist(), test_size=0.10, @@ -87,7 +88,7 @@ class TRIDENTSmall(PublicPrometheusDataset): "U. Melbourne." ) _available_backends = ["sqlite"] - _file_hashes = {"sqlite": "F2R8qb8JW7"} + _file_hashes = {"sqlite": "aooZEpVsAM"} _citation = None @@ -105,7 +106,7 @@ class PONESmall(PublicPrometheusDataset): "U. Melbourne." ) _available_backends = ["sqlite"] - _file_hashes = {"sqlite": "e9ZSVMiykD"} + _file_hashes = {"sqlite": "GIt0hlG9qI"} _citation = None @@ -123,5 +124,5 @@ class BaikalGVDSmall(PublicPrometheusDataset): "U. Melbourne." ) _available_backends = ["sqlite"] - _file_hashes = {"sqlite": "ebLJHjPDqy"} + _file_hashes = {"sqlite": "FtFs5fxXB7"} _citation = None From 3c9457b96c9812e0ddcd7181ee9e753f1602c69b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 15:56:12 +0200 Subject: [PATCH 25/29] update selection logic --- src/graphnet/datasets/prometheus_datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 8355a96b7..8e6cbff05 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -46,7 +46,6 @@ def _prepare_args( database=dataset_path, query="SELECT event_no FROM mc_truth", ) - event_nos = event_nos.loc[event_nos["n_photons"] < 1000] train_val, test = train_test_split( event_nos["event_no"].tolist(), test_size=0.10, From b8c24fb1a92b9848c5a359f1482a7948e23435ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 16:46:13 +0200 Subject: [PATCH 26/29] change dataset path logic --- src/graphnet/datasets/prometheus_datasets.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 8e6cbff05..8f6cdc722 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -2,6 +2,7 @@ from typing import Dict, Any, List, Tuple, Union import os from sklearn.model_selection import train_test_split +from glob import glob from graphnet.training.labels import Direction, Track from graphnet.data import ERDAHostedDataset @@ -40,8 +41,9 @@ def _prepare_args( Returns: Dataset arguments and selections """ - dataset_path = os.path.join(self.dataset_dir, "merged.db") - + dataset_paths = glob(os.path.join(self.dataset_dir, "*.db")) + assert len(dataset_paths) == 1 + dataset_path = dataset_paths[0] event_nos = query_database( database=dataset_path, query="SELECT event_no FROM mc_truth", From 8af4e380d5ba0de5b6b7f2bdf1c60cb813dff854 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 18:12:44 +0200 Subject: [PATCH 27/29] use truth table in query of PublicPrometheusDataset --- src/graphnet/datasets/prometheus_datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 0f8d37620..59ff27e08 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -43,7 +43,8 @@ def _prepare_args( dataset_path = os.path.join(self.dataset_dir, "merged.db") event_nos = query_database( - database=dataset_path, query="SELECT event_no FROM mc_truth" + database=dataset_path, + query=f"SELECT event_no FROM {self._truth_table[0]}" ) train_val, test = train_test_split( From 6840457bad90759376156078270f6cb8239769f6 Mon Sep 17 00:00:00 2001 From: Rasmus Oersoe Date: Wed, 1 May 2024 18:18:39 +0200 Subject: [PATCH 28/29] arturo comments --- src/graphnet/datasets/prometheus_datasets.py | 39 +++++++++++++------- 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 338b7df5c..2248403c7 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -3,6 +3,7 @@ import os from sklearn.model_selection import train_test_split from glob import glob +import numpy as np from graphnet.training.labels import Direction, Track from graphnet.data import ERDAHostedDataset @@ -39,21 +40,31 @@ def _prepare_args( features: List of features from user to use as input. truth: List of event-level truth variables from user. - Returns: Dataset arguments and selections + Returns: Dataset arguments, train/val selection, test selection """ - dataset_paths = glob(os.path.join(self.dataset_dir, "*.db")) - assert len(dataset_paths) == 1 - dataset_path = dataset_paths[0] - event_nos = query_database( - database=dataset_path, - query=f"SELECT event_no FROM {self._truth_table[0]}" - ) - train_val, test = train_test_split( - event_nos["event_no"].tolist(), - test_size=0.10, - random_state=42, - shuffle=True, - ) + if backend == 'sqlite': + dataset_paths = glob(os.path.join(self.dataset_dir, "*.db")) + assert len(dataset_paths) == 1 + dataset_path = dataset_paths[0] + event_nos = query_database( + database=dataset_path, + query=f"SELECT event_no FROM {self._truth_table[0]}" + ) + train_val, test = train_test_split( + event_nos["event_no"].tolist(), + test_size=0.10, + random_state=42, + shuffle=True, + ) + elif backend == 'parquet': + dataset_path = self.dataset_dir + n_batches = len(glob(os.path.join(dataset_path,self._truth_table,'*.parquet'))) + train_val, test = train_test_split( + np.arange(0, n_batches), + test_size=0.10, + random_state=42, + shuffle=True, + ) dataset_args = { "truth_table": self._truth_table, "pulsemaps": self._pulsemaps, From 71b56a4feeb6baf5d20786c39a4fd74b0bb25f71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20=C3=98rs=C3=B8e?= Date: Wed, 1 May 2024 18:46:15 +0200 Subject: [PATCH 29/29] hooks --- src/graphnet/datasets/prometheus_datasets.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/graphnet/datasets/prometheus_datasets.py b/src/graphnet/datasets/prometheus_datasets.py index 2248403c7..1f3435e4b 100644 --- a/src/graphnet/datasets/prometheus_datasets.py +++ b/src/graphnet/datasets/prometheus_datasets.py @@ -36,19 +36,19 @@ def _prepare_args( """Prepare arguments for dataset. Args: - backend: backend of dataset. Either "parquet" or "sqlite" + backend: backend of dataset. Either "parquet" or "sqlite". features: List of features from user to use as input. truth: List of event-level truth variables from user. Returns: Dataset arguments, train/val selection, test selection """ - if backend == 'sqlite': + if backend == "sqlite": dataset_paths = glob(os.path.join(self.dataset_dir, "*.db")) assert len(dataset_paths) == 1 dataset_path = dataset_paths[0] event_nos = query_database( - database=dataset_path, - query=f"SELECT event_no FROM {self._truth_table[0]}" + database=dataset_path, + query=f"SELECT event_no FROM {self._truth_table[0]}", ) train_val, test = train_test_split( event_nos["event_no"].tolist(), @@ -56,9 +56,13 @@ def _prepare_args( random_state=42, shuffle=True, ) - elif backend == 'parquet': + elif backend == "parquet": dataset_path = self.dataset_dir - n_batches = len(glob(os.path.join(dataset_path,self._truth_table,'*.parquet'))) + n_batches = len( + glob( + os.path.join(dataset_path, self._truth_table, "*.parquet") + ) + ) train_val, test = train_test_split( np.arange(0, n_batches), test_size=0.10,