Skip to content

Commit

Permalink
Merge pull request #623 from RasmusOrsoe/geometry_tables_updated
Browse files Browse the repository at this point in the history
Geometry Tables
  • Loading branch information
RasmusOrsoe authored Nov 6, 2023
2 parents d1b97d5 + 035b81b commit cddc567
Show file tree
Hide file tree
Showing 15 changed files with 322 additions and 16 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,9 @@ data/examples/output/
**.parquet
# Exception to pre-trained folders; here we want all .pth files
!/graphnet/src/graphnet/models/pretrained/**/**/**/**/**.pth
# Exception to geometry tables
!/data/geometry_tables/**/**.parquet
!/data/tests/sqlite/upgrade_genie_step4_140028_000998_first_5_frames/upgrade_genie_step4_140028_000998_first_5_frames.db
!/data/examples/sqlite/prometheus/prometheus-events.db
# Notebooks
**.ipynb
Binary file modified data/examples/parquet/prometheus/prometheus-events.parquet
Binary file not shown.
Binary file modified data/examples/sqlite/prometheus/prometheus-events.db
Binary file not shown.
Binary file added data/geometry_tables/icecube/icecube86.parquet
Binary file not shown.
Binary file not shown.
Binary file added data/geometry_tables/prometheus/orca_150.parquet
Binary file not shown.
Binary file not shown.
4 changes: 1 addition & 3 deletions examples/04_training/02_train_tito_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,7 @@ def main(
) = make_train_validation_dataloader(
db=config["path"],
graph_definition=graph_definition,
selection=list(
range(0, 100)
), # subset of events for speeding up training
selection=None,
pulsemaps=config["pulsemap"],
features=features,
truth=truth,
Expand Down
5 changes: 5 additions & 0 deletions src/graphnet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@
PRETRAINED_MODEL_DIR = os.path.join(
GRAPHNET_ROOT_DIR, "src", "graphnet", "models", "pretrained"
)

# Geometry Tables
# Root directory for geometry tables, located under `DATA_DIR`.
GEOMETRY_TABLE_DIR = os.path.join(DATA_DIR, "geometry_tables")
# Sub-directories of `GEOMETRY_TABLE_DIR`, one per group of detectors.
ICECUBE_GEOMETRY_TABLE_DIR = os.path.join(GEOMETRY_TABLE_DIR, "icecube")
PROMETHEUS_GEOMETRY_TABLE_DIR = os.path.join(GEOMETRY_TABLE_DIR, "prometheus")
2 changes: 1 addition & 1 deletion src/graphnet/models/detector/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Detector-specific modules, for data ingestion and standardisation."""

from .icecube import IceCube86, IceCubeDeepCore
from .icecube import IceCube86, IceCubeDeepCore, IceCubeUpgrade
from .detector import Detector
45 changes: 38 additions & 7 deletions src/graphnet/models/detector/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from torch_geometric.data import Data
import torch
import pandas as pd

from graphnet.models import Model
from graphnet.utilities.decorators import final
Expand All @@ -24,26 +25,56 @@ def feature_map(self) -> Dict[str, Callable]:

@final
def forward(  # type: ignore
    self, input_features: torch.tensor, input_feature_names: List[str]
) -> Data:
    """Standardize `input_features` feature by feature.

    Final entry point that only delegates to `_standardize`.

    Args:
        input_features: Feature tensor whose column `i` holds the feature
            named `input_feature_names[i]`.
        input_feature_names: Names of the columns of `input_features`.

    Returns:
        Whatever `_standardize` returns for the same arguments.
    """
    return self._standardize(input_features, input_feature_names)

@property
def geometry_table(self) -> pd.DataFrame:
    """Return the detector's geometry table, reading it on first access.

    The table is read with `pandas.read_parquet` from the class variable
    `geometry_table_path` and stored on the instance as `_geometry_table`;
    later accesses return the stored frame without reading the file again.

    Returns:
        The geometry table as a `pandas.DataFrame`.

    Raises:
        AssertionError: If `geometry_table_path` is not set.
    """
    # Logical `not`, not bitwise `~`: `~True == -2` and `~False == -1` are
    # both truthy, which would force a re-read of the file on every access.
    if not hasattr(self, "_geometry_table"):
        if not hasattr(self, "geometry_table_path"):
            message = (
                f"{self.__class__.__name__} does not have class "
                "variable `geometry_table_path` set."
            )
            self.error(message)
            # Raised explicitly so the check survives `python -O`; the
            # exception type is kept for callers that catch it.
            raise AssertionError(message)
        self._geometry_table = pd.read_parquet(self.geometry_table_path)
    return self._geometry_table

@property
def string_index_name(self) -> str:
    """Name of the column that identifies strings.

    Returns:
        The value of the class variable `string_id_column`.
    """
    column_name = self.string_id_column
    return column_name

@property
def sensor_position_names(self) -> List[str]:
    """Names of the columns that hold the sensor xyz coordinates.

    Returns:
        The value of the class variable `xyz`.
    """
    position_columns = self.xyz
    return position_columns

@property
def sensor_index_name(self) -> str:
    """Name of the column that identifies individual sensors.

    Returns:
        The value of the class variable `sensor_id_column`.
    """
    column_name = self.sensor_id_column
    return column_name

@final
def _standardize(
    self, input_features: torch.tensor, input_feature_names: List[str]
) -> Data:
    """Apply each feature's standardization function to its column.

    Column `i` of `input_features` is overwritten in place with the result
    of the function that `feature_map()` registers under
    `input_feature_names[i]`.

    Args:
        input_features: Two-dimensional feature container indexed as
            `[:, i]`; modified in place.
        input_feature_names: Name of the feature held in each column.

    Returns:
        The same `input_features` object after standardization.
        NOTE(review): the annotation says `Data`, but the object returned
        here is the feature tensor itself.

    Raises:
        KeyError: If a feature name has no entry in `feature_map()`.
    """
    # The mapping does not change while looping, so build it only once.
    feature_map = self.feature_map()
    for idx, feature in enumerate(input_feature_names):
        # Only the lookup is guarded, so a `KeyError` raised inside a
        # standardization function is not misreported as a missing entry.
        try:
            standardize = feature_map[feature]
        except KeyError:
            self.warning(
                f"""No Standardization function found for '{feature}'"""
            )
            raise
        input_features[:, idx] = standardize(  # type: ignore
            input_features[:, idx]
        )
    return input_features

def _identity(self, x: torch.tensor) -> torch.tensor:
"""Apply no standardization to input."""
Expand Down
18 changes: 17 additions & 1 deletion src/graphnet/models/detector/icecube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

from typing import Dict, Callable
import torch
import os

from graphnet.models.detector.detector import Detector
from graphnet.constants import ICECUBE_GEOMETRY_TABLE_DIR


class IceCube86(Detector):
"""`Detector` class for IceCube-86."""

geometry_table_path = os.path.join(
ICECUBE_GEOMETRY_TABLE_DIR, "icecube86.parquet"
)
xyz = ["dom_x", "dom_y", "dom_z"]
string_id_column = "string"
sensor_id_column = "sensor_id"

def feature_map(self) -> Dict[str, Callable]:
"""Map standardization functions to each dimension of input data."""
feature_map = {
Expand Down Expand Up @@ -63,7 +72,7 @@ def _charge(self, x: torch.tensor) -> torch.tensor:
return torch.log10(x) / 3.0


class IceCubeDeepCore(Detector):
class IceCubeDeepCore(IceCube86):
"""`Detector` class for IceCube-DeepCore."""

def feature_map(self) -> Dict[str, Callable]:
Expand Down Expand Up @@ -98,6 +107,13 @@ def _pmt_area(self, x: torch.tensor) -> torch.tensor:
class IceCubeUpgrade(Detector):
"""`Detector` class for IceCube-Upgrade."""

geometry_table_path = os.path.join(
ICECUBE_GEOMETRY_TABLE_DIR, "icecube_upgrade.parquet"
)
xyz = ["dom_x", "dom_y", "dom_z"]
string_id_column = "string"
sensor_id_column = "sensor_id"

def feature_map(self) -> Dict[str, Callable]:
"""Map standardization functions to each dimension of input data."""
feature_map = {
Expand Down
17 changes: 15 additions & 2 deletions src/graphnet/models/detector/prometheus.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@

from typing import Dict, Callable
import torch
import os

from graphnet.models.detector.detector import Detector
from graphnet.constants import PROMETHEUS_GEOMETRY_TABLE_DIR


class Prometheus(Detector):
class ORCA150(Detector):
"""`Detector` class for Prometheus prototype."""

geometry_table_path = os.path.join(
PROMETHEUS_GEOMETRY_TABLE_DIR, "orca_150.parquet"
)
xyz = ["sensor_pos_x", "sensor_pos_y", "sensor_pos_z"]
string_id_column = "sensor_string_id"
sensor_id_column = "sensor_id"

def feature_map(self) -> Dict[str, Callable]:
"""Map standardization functions to each dimension."""
feature_map = {
Expand All @@ -26,4 +35,8 @@ def _sensor_pos_z(self, x: torch.tensor) -> torch.tensor:
return (x + 350) / 100

def _t(self, x: torch.tensor) -> torch.tensor:
    """Standardize the `t` feature by dividing it by 1.05e04.

    Args:
        x: Values of the `t` feature (presumably pulse times; confirm
            against the class's `feature_map`).

    Returns:
        `x` divided by 1.05e04.
    """
    return x / 1.05e04


class Prometheus(ORCA150):
    """Reference to ORCA150.

    Subclass that adds nothing of its own, so `Prometheus` behaves exactly
    like `ORCA150` under a second name.
    """
116 changes: 115 additions & 1 deletion src/graphnet/models/graphs/graph_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def __init__(
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
add_inactive_sensors: bool = False,
sensor_mask: Optional[List[int]] = None,
string_mask: Optional[List[int]] = None,
sort_by: str = None,
):
"""Construct ´GraphDefinition´. The ´detector´ holds.
Expand All @@ -55,6 +59,12 @@ def __init__(
to None.
seed: seed or Generator used to randomly sample perturbations.
Defaults to None.
add_inactive_sensors: If True, inactive sensors will be appended
to the graph with padded pulse information. Defaults to False.
sensor_mask: A list of sensor IDs to be masked from the graph. Any
sensor listed here will be removed from the graph. Defaults to None.
string_mask: A list of string IDs to be masked from the graph. Every
sensor on a listed string will be removed from the graph. Cannot be
combined with `sensor_mask`. Defaults to None.
sort_by: Name of node feature to sort by. Defaults to None.
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)
Expand All @@ -64,6 +74,11 @@ def __init__(
self._edge_definition = edge_definition
self._node_definition = node_definition
self._perturbation_dict = perturbation_dict
self._sensor_mask = sensor_mask
self._string_mask = string_mask
self._add_inactive_sensors = add_inactive_sensors

self._resolve_masks()

if input_feature_names is None:
# Assume all features in Detector is used.
Expand All @@ -74,7 +89,19 @@ def __init__(
self._node_definition.set_output_feature_names(
self._input_feature_names
)

self.output_feature_names = self._node_definition._output_feature_names

# Sorting
if sort_by is not None:
assert isinstance(sort_by, str)
try:
sort_by = self.output_feature_names.index(sort_by) # type: ignore
except ValueError as e:
self.error(
f"{sort_by} not in node features {self.output_feature_names}."
)
raise e
self._sort_by = sort_by
# Set data type
self.to(dtype)

Expand Down Expand Up @@ -138,6 +165,18 @@ def forward( # type: ignore
input_feature_names=input_feature_names,
)

# Add inactive sensors if `add_inactive_sensors = True`
if self._add_inactive_sensors:
input_features = self._attach_inactive_sensors(
input_features, input_feature_names
)

# Mask out sensors if `sensor_mask` is given
if self._sensor_mask is not None:
input_features = self._mask_sensors(
input_features, input_feature_names
)

# Gaussian perturbation of each column if perturbation dict is given
input_features = self._perturb_input(input_features)

Expand All @@ -149,6 +188,8 @@ def forward( # type: ignore

# Create graph & get new node feature names
graph, node_feature_names = self._node_definition(input_features)
if self._sort_by is not None:
graph.x = graph.x[graph.x[:, self._sort_by].sort()[1]]

# Enforce dtype
graph.x = graph.x.type(self.dtype)
Expand Down Expand Up @@ -197,6 +238,79 @@ def forward( # type: ignore
graph["graph_definition"] = self.__class__.__name__
return graph

def _resolve_masks(self) -> None:
    """Validate the mask arguments and reduce them to a sensor mask.

    At most one of `sensor_mask` and `string_mask` may be given. A string
    mask is converted with `_convert_string_to_sensor_mask` and stored in
    `self._sensor_mask`. With no mask, or only a sensor mask, nothing is
    changed.

    Raises:
        AssertionError: If both `sensor_mask` and `string_mask` are given.
    """
    has_sensor_mask = self._sensor_mask is not None
    has_string_mask = self._string_mask is not None

    if has_sensor_mask and has_string_mask:
        # Raised explicitly instead of `assert 1 == 2` so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(
            """Got arguments for both `sensor_mask`and `string_mask`. Please specify only one. """
        )

    if has_string_mask:
        self._sensor_mask = self._convert_string_to_sensor_mask()

def _convert_string_to_sensor_mask(self) -> List[int]:
    """Convert the string mask into the equivalent sensor mask.

    Looks up, in the detector's geometry table, every row whose string id
    is listed in `self._string_mask` and collects the sensor ids of those
    rows.

    Returns:
        Sensor ids of all sensors on the masked strings, in geometry-table
        order.
    """
    detector = self._detector
    # Column names are read through the public `Detector` accessors, the
    # same way `_mask_sensors` reads them.
    string_id_column = detector.string_index_name
    sensor_id_column = detector.sensor_index_name
    geometry_table = detector.geometry_table

    on_masked_string = geometry_table[string_id_column].isin(
        self._string_mask
    )
    return geometry_table.loc[on_masked_string, sensor_id_column].tolist()

def _attach_inactive_sensors(
    self, input_features: np.ndarray, input_feature_names: List[str]
) -> np.ndarray:
    """Attach inactive sensors to `input_features`.

    This function will query the detector geometry table and add any sensor
    in the geometry table that is not already present in `input_features`.

    Args:
        input_features: Array with one row per entry and one column per
            name in `input_feature_names`.
        input_feature_names: Names of the columns of `input_features`.

    Returns:
        `input_features` with the geometry-table rows that were not matched
        by the lookup appended after the original rows.
    """
    # Geometry-table index labels matched by the rows of `input_features`.
    lookup = self._geometry_table_lookup(
        input_features, input_feature_names
    )
    geometry_table = self._detector.geometry_table
    # Same rows with the original index discarded.
    unique_sensors = geometry_table.reset_index(drop=True)

    # multiple lines to avoid long line:
    # Boolean array, True for table rows whose index label is not in `lookup`.
    inactive_idx = ~geometry_table.index.isin(lookup)
    # NOTE(review): assumes every name in `input_feature_names` is still a
    # column after the index is dropped; confirm against the table schema.
    inactive_sensors = unique_sensors.loc[
        inactive_idx, input_feature_names
    ]
    input_features = np.concatenate(
        [input_features, inactive_sensors.to_numpy()], axis=0
    )
    return input_features

def _mask_sensors(
    self, input_features: np.ndarray, input_feature_names: List[str]
) -> np.ndarray:
    """Drop the rows of `input_features` that belong to masked sensors.

    Each row is matched to a geometry-table entry through
    `_geometry_table_lookup`; rows whose sensor id is listed in
    `self._sensor_mask` are removed.

    Args:
        input_features: Array with one row per entry and one column per
            name in `input_feature_names`.
        input_feature_names: Names of the columns of `input_features`.

    Returns:
        The rows of `input_features` whose sensor id is not masked.
    """
    detector = self._detector
    id_column = detector.sensor_index_name
    table = detector.geometry_table

    table_rows = self._geometry_table_lookup(
        input_features=input_features,
        input_feature_names=input_feature_names,
    )
    sensor_ids = table.loc[table_rows, id_column]
    keep = ~sensor_ids.isin(self._sensor_mask)

    return input_features[keep, :]

def _geometry_table_lookup(
    self, input_features: np.ndarray, input_feature_names: List[str]
) -> np.ndarray:
    """Look up the geometry-table entries matching each row's xyz position.

    The position columns named by the detector's `sensor_position_names`
    are picked out of `input_features`, turned into one coordinate tuple
    per row, and used as `.loc` keys into the geometry table.

    Args:
        input_features: Array with one row per entry and one column per
            name in `input_feature_names`.
        input_feature_names: Names of the columns of `input_features`;
            must contain every position name.

    Returns:
        The index of the geometry-table rows selected by the coordinate
        tuples (a pandas index object, although annotated `np.ndarray`).
    """
    detector = self._detector

    position_columns = []
    for name in detector.sensor_position_names:
        position_columns.append(input_feature_names.index(name))

    # One tuple per coordinate axis, then transposed into one tuple per row.
    per_axis = [tuple(input_features[:, col]) for col in position_columns]
    row_keys = list(zip(*per_axis))

    matched = detector.geometry_table.loc[row_keys, :]
    return matched.index

def _validate_input(
self, input_features: np.array, input_feature_names: List[str]
) -> None:
Expand Down
Loading

0 comments on commit cddc567

Please sign in to comment.