diff --git a/src/graphnet/models/detector/detector.py b/src/graphnet/models/detector/detector.py index 9b9fc61b0..599865bb9 100644 --- a/src/graphnet/models/detector/detector.py +++ b/src/graphnet/models/detector/detector.py @@ -1,7 +1,7 @@ """Base detector-specific `Model` class(es).""" from abc import abstractmethod -from typing import Dict, Callable, List +from typing import Dict, Callable, List, Optional from torch_geometric.data import Data import torch @@ -14,10 +14,13 @@ class Detector(Model): """Base class for all detector-specific read-ins in graphnet.""" - def __init__(self) -> None: + def __init__( + self, replace_with_identity: Optional[List[bool]] = None + ) -> None: """Construct `Detector`.""" # Base class constructor super().__init__(name=__name__, class_name=self.__class__.__name__) + self._replace_with_identity = replace_with_identity @abstractmethod def feature_map(self) -> Dict[str, Callable]: @@ -64,9 +67,13 @@ def sensor_index_name(self) -> str: def _standardize( self, input_features: torch.tensor, input_feature_names: List[str] ) -> Data: - for idx, feature in enumerate(input_feature_names): + feature_map = self.feature_map() + if self._replace_with_identity is not None: + for feature in self._replace_with_identity: + feature_map[feature] = self._identity # type: ignore + for idx, feature in enumerate(input_feature_names): # type: ignore try: - input_features[:, idx] = self.feature_map()[feature]( # type: ignore + input_features[:, idx] = feature_map[feature]( # type: ignore input_features[:, idx] ) except KeyError as e: diff --git a/src/graphnet/models/task/task.py b/src/graphnet/models/task/task.py index cd750f35d..882627440 100644 --- a/src/graphnet/models/task/task.py +++ b/src/graphnet/models/task/task.py @@ -107,11 +107,11 @@ def __init__( self._transform_prediction_training: Callable[ [Tensor], Tensor - ] = lambda x: x + ] = self._identity self._transform_prediction_inference: Callable[ [Tensor], Tensor - ] = lambda x: x - self._transform_target: Callable[[Tensor], Tensor] = lambda x: x + ] = self._identity + self._transform_target: Callable[[Tensor], Tensor] = self._identity self._validate_and_set_transforms( transform_prediction_and_target, transform_target, @@ -217,6 +217,10 @@ def _validate_and_set_transforms( if transform_inference is not None: self._transform_prediction_inference = transform_inference + def _identity(self, x: Tensor) -> Tensor: + """Identity function.""" + return x + class LearnedTask(Task): """Task class with a learned mapping.