diff --git a/src/graphnet/training/labels.py b/src/graphnet/training/labels.py index bacaa16ee..4a0debff9 100644 --- a/src/graphnet/training/labels.py +++ b/src/graphnet/training/labels.py @@ -103,4 +103,4 @@ def __init__( def __call__(self, graph: Data) -> torch.tensor: """Compute label for `graph`.""" label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1) - return torch.tensor(label, dtype=torch.int64) + return label # torch.tensor(label, dtype=torch.int64)