Skip to content

Commit

Permalink
add missing abs to track label
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusOrsoe committed May 20, 2024
1 parent 2efd732 commit b31aa42
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/graphnet/training/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,5 +102,5 @@ def __init__(

def __call__(self, graph: Data) -> torch.tensor:
"""Compute label for `graph`."""
label = (graph[self._pid_key] == 14) & (graph[self._int_key] == 1)
label = (torch.abs(graph[self._pid_key]) == 14) & (graph[self._int_key] == 1)
return label.type(torch.int)

0 comments on commit b31aa42

Please sign in to comment.