Skip to content

Commit

Permalink
Merge pull request graphnet-team#769 from Aske-Rosted/replace_with_id…
Browse files Browse the repository at this point in the history
…entity

Optional replacing of feature mapping with identity
  • Loading branch information
Aske-Rosted authored Nov 29, 2024
2 parents 4898fb4 + 0966f89 commit c826a2d
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions src/graphnet/models/detector/detector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base detector-specific `Model` class(es)."""

from abc import abstractmethod
from typing import Dict, Callable, List
from typing import Dict, Callable, List, Optional

from torch_geometric.data import Data
import torch
Expand All @@ -14,10 +14,19 @@
class Detector(Model):
"""Base class for all detector-specific read-ins in graphnet."""

def __init__(self) -> None:
"""Construct `Detector`."""
def __init__(
self, replace_with_identity: Optional[List[str]] = None
) -> None:
"""Construct `Detector`.
Args:
replace_with_identity: A list of feature names from the
feature_map that should be replaced with the identity
function.
"""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)
self._replace_with_identity = replace_with_identity

@abstractmethod
def feature_map(self) -> Dict[str, Callable]:
Expand Down Expand Up @@ -64,9 +73,13 @@ def sensor_index_name(self) -> str:
def _standardize(
self, input_features: torch.tensor, input_feature_names: List[str]
) -> Data:
feature_map = self.feature_map()
if self._replace_with_identity is not None:
for feature in self._replace_with_identity:
feature_map[feature] = self._identity
for idx, feature in enumerate(input_feature_names):
try:
input_features[:, idx] = self.feature_map()[
input_features[:, idx] = feature_map[
feature
]( # noqa: E501 # type: ignore
input_features[:, idx]
Expand Down

0 comments on commit c826a2d

Please sign in to comment.