From 3ff3fe73f8fee3d2c149584f48a41f660ddf39a5 Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Tue, 20 Aug 2024 22:56:57 +0800 Subject: [PATCH] only auto set ignored_index in multiclass Signed-off-by: Zhiyuan Chen --- danling/metrics/functional.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/danling/metrics/functional.py b/danling/metrics/functional.py index 0ce12f7f..1722cc60 100644 --- a/danling/metrics/functional.py +++ b/danling/metrics/functional.py @@ -21,6 +21,7 @@ from collections.abc import Sequence import torch +from chanfig.utils import NULL, Null from lazy_imports import try_import from torch import Tensor @@ -40,12 +41,14 @@ def auroc( num_labels: int | None = None, num_classes: int | None = None, task_weights: Tensor | None = None, - ignored_index: int = -100, + ignored_index: int | None | NULL = Null, **kwargs, ): te.check() if num_classes and num_labels: raise ValueError("Only one of num_classes or num_labels can be specified, but not both") + if ignored_index is Null: + ignored_index = -100 if num_classes else None input, target = preprocess(input, target, ignored_index=ignored_index) if num_labels is None and num_classes is None: return tef.binary_auroc(input=input, target=target, weight=weight, **kwargs) @@ -66,12 +69,14 @@ def auprc( num_labels: int | None = None, num_classes: int | None = None, task_weights: Tensor | None = None, - ignored_index: int = -100, + ignored_index: int | None | NULL = Null, **kwargs, ): te.check() if num_classes and num_labels: raise ValueError("Only one of num_classes or num_labels can be specified, but not both") + if ignored_index is Null: + ignored_index = -100 if num_classes else None input, target = preprocess(input, target, ignored_index=ignored_index) if num_labels is None and num_classes is None: return tef.binary_auprc(input=input, target=target, **kwargs) @@ -92,12 +97,14 @@ def accuracy( average: str | None = "micro", num_labels: int | None = None, num_classes: int | None = None, - ignored_index: int = -100, + ignored_index: int | None | NULL = Null, **kwargs, ): te.check() if num_classes and num_labels: raise ValueError("Only one of num_classes or num_labels can be specified, but not both") + if ignored_index is Null: + ignored_index = -100 if num_classes else None input, target = preprocess(input, target, ignored_index=ignored_index) if num_labels is None and num_classes is None: return tef.binary_accuracy(input=input, target=target, threshold=threshold, **kwargs) @@ -114,7 +121,7 @@ def mcc( threshold: float = 0.5, num_labels: int | None = None, num_classes: int | None = None, - ignored_index: int = -100, + ignored_index: int | None | NULL = Null, ): tm.check() if num_classes and num_labels: @@ -124,6 +131,8 @@ def mcc( task = "multiclass" if num_labels: task = "multilabel" + if ignored_index is Null: + ignored_index = -100 if num_classes else None input, target = preprocess(input, target, ignored_index=ignored_index) try: return tmf.matthews_corrcoef(