Skip to content

Commit

Permalink
replace with identity possibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed Aug 30, 2024
1 parent 019bb3f commit 832df6c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
15 changes: 11 additions & 4 deletions src/graphnet/models/detector/detector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Base detector-specific `Model` class(es)."""

from abc import abstractmethod
from typing import Dict, Callable, List
from typing import Dict, Callable, List, Optional

from torch_geometric.data import Data
import torch
Expand All @@ -14,10 +14,13 @@
class Detector(Model):
"""Base class for all detector-specific read-ins in graphnet."""

def __init__(self) -> None:
def __init__(
self, replace_with_identity: Optional[List[bool]] = None
) -> None:
"""Construct `Detector`."""
# Base class constructor
super().__init__(name=__name__, class_name=self.__class__.__name__)
self._replace_with_identity = replace_with_identity

@abstractmethod
def feature_map(self) -> Dict[str, Callable]:
Expand Down Expand Up @@ -64,9 +67,13 @@ def sensor_index_name(self) -> str:
def _standardize(
self, input_features: torch.tensor, input_feature_names: List[str]
) -> Data:
for idx, feature in enumerate(input_feature_names):
feature_map = self.feature_map()
if self._replace_with_identity is not None:
for feature in self._replace_with_identity:
feature_map[feature] = self._identity # type: ignore
for idx, feature in enumerate(input_feature_names): # type: ignore
try:
input_features[:, idx] = self.feature_map()[feature]( # type: ignore
input_features[:, idx] = feature_map[feature]( # type: ignore
input_features[:, idx]
)
except KeyError as e:
Expand Down
10 changes: 7 additions & 3 deletions src/graphnet/models/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def __init__(

self._transform_prediction_training: Callable[
[Tensor], Tensor
] = lambda x: x
] = self._identity
self._transform_prediction_inference: Callable[
[Tensor], Tensor
] = lambda x: x
self._transform_target: Callable[[Tensor], Tensor] = lambda x: x
] = self._identity
self._transform_target: Callable[[Tensor], Tensor] = self._identity
self._validate_and_set_transforms(
transform_prediction_and_target,
transform_target,
Expand Down Expand Up @@ -217,6 +217,10 @@ def _validate_and_set_transforms(
if transform_inference is not None:
self._transform_prediction_inference = transform_inference

def _identity(self, x: Tensor) -> Tensor:
"""Identity function."""
return x


class LearnedTask(Task):
"""Task class with a learned mapping.
Expand Down

0 comments on commit 832df6c

Please sign in to comment.