Skip to content

Commit

Permalink
Merge pull request #725 from RasmusOrsoe/fix_track_label
Browse files Browse the repository at this point in the history
add missing `abs` to `Track` label
  • Loading branch information
RasmusOrsoe authored May 28, 2024
2 parents fc5d955 + 1ee2b53 commit e5b9450
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/graphnet/training/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,6 @@ def __init__(

def __call__(self, graph: Data) -> torch.tensor:
"""Compute label for `graph`."""
label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1)
return label.type(torch.int)
is_numu = torch.abs(graph[self._pid_key]) == 14
is_cc = graph[self._int_key] == 1
return (is_numu & is_cc).type(torch.int)

0 comments on commit e5b9450

Please sign in to comment.