Skip to content

Commit

Permalink
fix: accuracy_score not computing properly for binary classification
Browse files Browse the repository at this point in the history
  • Loading branch information
marcpinet committed Nov 13, 2023
1 parent 79361bd commit e292a21
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions neuralnetlib/metrics.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
import numpy as np

from neuralnetlib.utils import apply_threshold

def accuracy_score(y_pred: np.ndarray, y_true: np.ndarray) -> float:
if y_pred.shape[1] == 1: # Binary classification
y_pred_classes = apply_threshold(y_pred)

def accuracy_score(y_pred: np.ndarray, y_true: np.ndarray, threshold: float = 0.5) -> float:
if y_pred.ndim == 1 or y_pred.shape[1] == 1: # Binary classification
y_pred_classes = apply_threshold(y_pred, threshold).ravel()
else: # Multiclass classification
y_pred_classes = np.argmax(y_pred, axis=1)

if y_true.ndim == 1 or y_true.shape[1] == 1: # If y_true is not one-hot encoded
y_true_classes = y_true.ravel()
else:
y_true_classes = np.argmax(y_true, axis=1)

return np.mean(y_pred_classes == y_true_classes)

return np.mean(y_pred_classes == y_true_classes)

def f1_score(y_pred: np.ndarray, y_true: np.ndarray) -> float:
precision = precision_score(y_pred, y_true)
recall = recall_score(y_pred, y_true)
def f1_score(y_pred: np.ndarray, y_true: np.ndarray, threshold: float = 0.5) -> float:
precision = precision_score(y_pred, y_true, threshold)
recall = recall_score(y_pred, y_true, threshold)
return 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0


def recall_score(y_pred: np.ndarray, y_true: np.ndarray) -> float:
y_pred_labels = (y_pred >= 0.5).astype(int) if y_pred.shape[1] == 1 else np.argmax(y_pred, axis=1)
def recall_score(y_pred: np.ndarray, y_true: np.ndarray, threshold: float = 0.5) -> float:
y_pred_labels = apply_threshold(y_pred, threshold) if y_pred.shape[1] == 1 else np.argmax(y_pred, axis=1)
y_true_labels = y_true if y_true.ndim == 1 or y_true.shape[1] == 1 else np.argmax(y_true, axis=1)
classes = np.unique(y_true_labels)
recall_scores = []
Expand All @@ -35,8 +37,9 @@ def recall_score(y_pred: np.ndarray, y_true: np.ndarray) -> float:

return np.mean(recall_scores)

def precision_score(y_pred: np.ndarray, y_true: np.ndarray) -> float:
y_pred_labels = (y_pred >= 0.5).astype(int) if y_pred.shape[1] == 1 else np.argmax(y_pred, axis=1)

def precision_score(y_pred: np.ndarray, y_true: np.ndarray, threshold: float = 0.5) -> float:
y_pred_labels = apply_threshold(y_pred, threshold) if y_pred.shape[1] == 1 else np.argmax(y_pred, axis=1)
y_true_labels = y_true if y_true.ndim == 1 or y_true.shape[1] == 1 else np.argmax(y_true, axis=1)
classes = np.unique(y_true_labels)
precision_scores = []
Expand Down

0 comments on commit e292a21

Please sign in to comment.