Skip to content

Commit

Permalink
only auto set ignored_index in multiclass
Browse files Browse the repository at this point in the history
Signed-off-by: Zhiyuan Chen <[email protected]>
  • Loading branch information
ZhiyuanChen committed Aug 20, 2024
1 parent fcc57bf commit 3ff3fe7
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions danling/metrics/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from collections.abc import Sequence

import torch
from chanfig.utils import NULL, Null
from lazy_imports import try_import
from torch import Tensor

Expand All @@ -40,12 +41,14 @@ def auroc(
num_labels: int | None = None,
num_classes: int | None = None,
task_weights: Tensor | None = None,
ignored_index: int = -100,
ignored_index: int | None | NULL = Null,
**kwargs,
):
te.check()
if num_classes and num_labels:
raise ValueError("Only one of num_classes or num_labels can be specified, but not both")
if ignored_index is Null:
ignored_index = -100 if num_classes else None
input, target = preprocess(input, target, ignored_index=ignored_index)
if num_labels is None and num_classes is None:
return tef.binary_auroc(input=input, target=target, weight=weight, **kwargs)
Expand All @@ -66,12 +69,14 @@ def auprc(
num_labels: int | None = None,
num_classes: int | None = None,
task_weights: Tensor | None = None,
ignored_index: int = -100,
ignored_index: int | None | NULL = Null,
**kwargs,
):
te.check()
if num_classes and num_labels:
raise ValueError("Only one of num_classes or num_labels can be specified, but not both")
if ignored_index is Null:
ignored_index = -100 if num_classes else None
input, target = preprocess(input, target, ignored_index=ignored_index)
if num_labels is None and num_classes is None:
return tef.binary_auprc(input=input, target=target, **kwargs)
Expand All @@ -92,12 +97,14 @@ def accuracy(
average: str | None = "micro",
num_labels: int | None = None,
num_classes: int | None = None,
ignored_index: int = -100,
ignored_index: int | None | NULL = Null,
**kwargs,
):
te.check()
if num_classes and num_labels:
raise ValueError("Only one of num_classes or num_labels can be specified, but not both")
if ignored_index is Null:
ignored_index = -100 if num_classes else None
input, target = preprocess(input, target, ignored_index=ignored_index)
if num_labels is None and num_classes is None:
return tef.binary_accuracy(input=input, target=target, threshold=threshold, **kwargs)
Expand All @@ -114,7 +121,7 @@ def mcc(
threshold: float = 0.5,
num_labels: int | None = None,
num_classes: int | None = None,
ignored_index: int = -100,
ignored_index: int | None | NULL = Null,
):
tm.check()
if num_classes and num_labels:
Expand All @@ -124,6 +131,8 @@ def mcc(
task = "multiclass"
if num_labels:
task = "multilabel"
if ignored_index is Null:
ignored_index = -100 if num_classes else None
input, target = preprocess(input, target, ignored_index=ignored_index)
try:
return tmf.matthews_corrcoef(
Expand Down

0 comments on commit 3ff3fe7

Please sign in to comment.