Skip to content

Commit

Permalink
Merge pull request #701 from RasmusOrsoe/add_datasets
Browse files Browse the repository at this point in the history
Add public datasets
  • Loading branch information
RasmusOrsoe authored May 2, 2024
2 parents 8d63a04 + 71b56a4 commit 0ac4410
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 0 deletions.
Binary file not shown.
1 change: 1 addition & 0 deletions src/graphnet/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
"""Contains pre-converted datasets ready for training."""
from .test_dataset import TestDataset
from .prometheus_datasets import TRIDENTSmall, BaikalGVDSmall, PONESmall
144 changes: 144 additions & 0 deletions src/graphnet/datasets/prometheus_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""Public datasets from Prometheus Simulation."""
from typing import Dict, Any, List, Tuple, Union
import os
from sklearn.model_selection import train_test_split
from glob import glob
import numpy as np

from graphnet.training.labels import Direction, Track
from graphnet.data import ERDAHostedDataset
from graphnet.data.constants import FEATURES
from graphnet.data.utilities import query_database


class PublicPrometheusDataset(ERDAHostedDataset):
    """A generic class for public Prometheus Datasets hosted using ERDA.

    Holds the table names, truth columns and feature list shared by the
    concrete datasets, and builds the dataset arguments together with a
    fixed 90/10 train-val/test split.
    """

    # Static Member Variables:
    _pulsemaps = ["photons"]
    # Name of the truth table; a plain string, used as-is everywhere below.
    _truth_table = "mc_truth"
    _event_truth = [
        "interaction",
        "initial_state_energy",
        "initial_state_type",
        "initial_state_zenith",
        "initial_state_azimuth",
        "initial_state_x",
        "initial_state_y",
        "initial_state_z",
    ]
    _pulse_truth = None
    _features = FEATURES.PROMETHEUS

    def _prepare_args(
        self, backend: str, features: List[str], truth: List[str]
    ) -> Tuple[Dict[str, Any], Union[List[int], None], Union[List[int], None]]:
        """Prepare arguments for dataset.

        Args:
            backend: backend of dataset. Either "parquet" or "sqlite".
            features: List of features from user to use as input.
            truth: List of event-level truth variables from user.

        Returns:
            Dataset arguments, train/val selection, test selection. For
            "sqlite" the selections are event numbers read from the truth
            table; for "parquet" they are indices over the parquet files
            found in the truth table directory.

        Raises:
            ValueError: If `backend` is neither "sqlite" nor "parquet".
        """
        if backend == "sqlite":
            dataset_paths = glob(os.path.join(self.dataset_dir, "*.db"))
            assert (
                len(dataset_paths) == 1
            ), f"Expected exactly one database, found {len(dataset_paths)}."
            dataset_path = dataset_paths[0]
            # `_truth_table` is a string: indexing it with `[0]` would yield
            # only its first character and query a non-existent table.
            event_nos = query_database(
                database=dataset_path,
                query=f"SELECT event_no FROM {self._truth_table}",
            )
            train_val, test = train_test_split(
                event_nos["event_no"].tolist(),
                test_size=0.10,
                random_state=42,
                shuffle=True,
            )
        elif backend == "parquet":
            dataset_path = self.dataset_dir
            # One selection index per parquet file in the truth table folder.
            n_batches = len(
                glob(
                    os.path.join(dataset_path, self._truth_table, "*.parquet")
                )
            )
            train_val, test = train_test_split(
                np.arange(0, n_batches),
                test_size=0.10,
                random_state=42,
                shuffle=True,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # `dataset_path` further down.
            raise ValueError(
                f"Unknown backend '{backend}'. "
                "Expected 'sqlite' or 'parquet'."
            )

        dataset_args = {
            "truth_table": self._truth_table,
            "pulsemaps": self._pulsemaps,
            "path": dataset_path,
            "graph_definition": self._graph_definition,
            "features": features,
            "truth": truth,
            "labels": {
                "direction": Direction(
                    azimuth_key="initial_state_azimuth",
                    zenith_key="initial_state_zenith",
                ),
                "track": Track(
                    pid_key="initial_state_type", interaction_key="interaction"
                ),
            },
        }

        return dataset_args, train_val, test


class TRIDENTSmall(PublicPrometheusDataset):
    """Small public Prometheus dataset simulated with a TRIDENT geometry.

    Holds roughly 1 million track events in the 10 GeV - 10 TeV range.
    """

    # Hosting details: "sqlite" is the only backend listed.
    _available_backends = ["sqlite"]
    _file_hashes = {"sqlite": "aooZEpVsAM"}

    # Descriptive metadata.
    _experiment = "TRIDENT Prometheus Simulation"
    _creator = "Rasmus F. Ørsøe"
    _citation = None
    _comments = (
        "Contains ~1 million track events. "
        "Simulation produced by Stephan Meighen-Berger, U. Melbourne."
    )


class PONESmall(PublicPrometheusDataset):
    """Small public Prometheus dataset simulated with a P-ONE geometry.

    Holds roughly 1 million track events in the 10 GeV - 10 TeV range.
    """

    # Hosting details: "sqlite" is the only backend listed.
    _available_backends = ["sqlite"]
    _file_hashes = {"sqlite": "GIt0hlG9qI"}

    # Descriptive metadata.
    _experiment = "P-ONE Prometheus Simulation"
    _creator = "Rasmus F. Ørsøe"
    _citation = None
    _comments = (
        "Contains ~1 million track events. "
        "Simulation produced by Stephan Meighen-Berger, U. Melbourne."
    )


class BaikalGVDSmall(PublicPrometheusDataset):
    """Small public Prometheus dataset simulated with a Baikal-GVD geometry.

    Holds roughly 1 million track events in the 10 GeV - 10 TeV range.
    """

    # Hosting details: "sqlite" is the only backend listed.
    _available_backends = ["sqlite"]
    _file_hashes = {"sqlite": "FtFs5fxXB7"}

    # Descriptive metadata.
    _experiment = "Baikal-GVD Prometheus Simulation"
    _creator = "Rasmus F. Ørsøe"
    _citation = None
    _comments = (
        "Contains ~1 million track events. "
        "Simulation produced by Stephan Meighen-Berger, U. Melbourne."
    )
27 changes: 27 additions & 0 deletions src/graphnet/models/detector/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,5 +335,32 @@ def _t(self, x: torch.tensor) -> torch.tensor:
return x / 1.05e04


class PONETriangle(Detector):
    """`Detector` class for the Prometheus PONE Triangle geometry."""

    geometry_table_path = os.path.join(
        PROMETHEUS_GEOMETRY_TABLE_DIR, "pone_triangle.parquet"
    )
    xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
    string_id_column = "sensor_string_id"
    sensor_id_column = "sensor_id"

    def feature_map(self) -> Dict[str, Callable]:
        """Return the standardization function for each input feature."""
        # All three position columns share one scaler; time has its own.
        scalers = {
            column: self._sensor_pos_xyz
            for column in ("sensor_pos_x", "sensor_pos_y", "sensor_pos_z")
        }
        scalers["t"] = self._t
        return scalers

    def _sensor_pos_xyz(self, x: torch.tensor) -> torch.tensor:
        """Scale a sensor position coordinate by dividing by 100."""
        scaled = x / 100
        return scaled

    def _t(self, x: torch.tensor) -> torch.tensor:
        """Scale a time value by dividing by 1.05e04."""
        scaled = x / 1.05e04
        return scaled


class Prometheus(ORCA150SuperDense):
"""Reference to ORCA150SuperDense."""
36 changes: 36 additions & 0 deletions src/graphnet/training/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,39 @@ def __call__(self, graph: Data) -> torch.tensor:
).reshape(-1, 1)
z = torch.cos(graph[self._zenith_key]).reshape(-1, 1)
return torch.cat((x, y, z), dim=1)


class Track(Label):
    """Class for producing NuMuCC label.

    Label is set to `1` if the event is a NuMu CC event (muon neutrino or
    muon anti-neutrino), else `0`.
    """

    def __init__(
        self,
        key: str = "track",
        pid_key: str = "pid",
        interaction_key: str = "interaction_type",
    ):
        """Construct `Track` label.

        Args:
            key: The name of the field in `Data` where the label will be
                stored. That is, `graph[key] = label`.
            pid_key: The name of the pre-existing key in `graph` that will
                be used to access the pdg encoding, used when calculating
                the track label.
            interaction_key: The name of the pre-existing key in `graph` that
                will be used to access the interaction type (1 denoting CC),
                used when calculating the track label.
        """
        self._pid_key = pid_key
        self._int_key = interaction_key

        # Base class constructor
        super().__init__(key=key)

    def __call__(self, graph: Data) -> torch.tensor:
        """Compute label for `graph`.

        Returns:
            Integer tensor that is `1` where the event is a muon
            (anti-)neutrino with interaction type `1`, and `0` elsewhere.
        """
        # PDG codes are signed: 14 is nu_mu and -14 is anti-nu_mu. Compare
        # the absolute value so anti-neutrino CC events count as tracks too.
        is_numu = torch.abs(graph[self._pid_key]) == 14
        is_cc = graph[self._int_key] == 1
        return (is_numu & is_cc).type(torch.int)

0 comments on commit 0ac4410

Please sign in to comment.