diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index df28f3191..0c86663ba 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -1,7 +1,7 @@ """Base detector-specific `Model` class(es).""" from abc import abstractmethod -from typing import Dict, Callable, List +from typing import Dict, Callable, List, Optional from torch_geometric.data import Data import torch @@ -14,10 +14,19 @@ class Detector(Model): """Base class for all detector-specific read-ins in graphnet.""" - def __init__(self) -> None: - """Construct `Detector`.""" + def __init__( + self, replace_with_identity: Optional[List[str]] = None + ) -> None: + """Construct `Detector`. + + Args: + replace_with_identity: A list of feature names from the + feature_map that should be replaced with the identity + function. + """ # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + self._replace_with_identity = replace_with_identity @abstractmethod def feature_map(self) -> Dict[str, Callable]: @@ -64,9 +73,13 @@ def sensor_index_name(self) -> str: def _standardize( self, input_features: torch.tensor, input_feature_names: List[str] ) -> Data: + feature_map = self.feature_map() + if self._replace_with_identity is not None: + for feature in self._replace_with_identity: + feature_map[feature] = self._identity for idx, feature in enumerate(input_feature_names): try: - input_features[:, idx] = self.feature_map()[ + input_features[:, idx] = feature_map[ feature ]( # noqa: E501 # type: ignore input_features[:, idx]