diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex
index 64291aa2..a1e1d124 100644
--- a/lib/scholar/metrics/classification.ex
+++ b/lib/scholar/metrics/classification.ex
@@ -1448,4 +1448,45 @@ defmodule Scholar.Metrics.Classification do
     assert_rank!(y_true, 1)
     assert_same_shape!(y_true, y_pred)
   end
+
+  @doc """
+  Computes the Matthews Correlation Coefficient (MCC), a measure of the quality of binary classifications.
+
+  It returns a value between -1 and 1, where 1 represents a perfect prediction, 0 a prediction no better
+  than random, and -1 total disagreement between prediction and observation.
+  """
+  defn mcc(y_true, y_pred) do
+    # Entries of the binary confusion matrix.
+    true_positives = binary_true_positives(y_true, y_pred)
+    true_negatives = binary_true_negatives(y_true, y_pred)
+    false_positives = binary_false_positives(y_true, y_pred)
+    false_negatives = binary_false_negatives(y_true, y_pred)
+
+    mcc_numerator = true_positives * true_negatives - false_positives * false_negatives
+
+    mcc_denominator =
+      Nx.sqrt(
+        (true_positives + false_positives) *
+          (true_positives + false_negatives) *
+          (true_negatives + false_positives) *
+          (true_negatives + false_negatives)
+      )
+
+    zero_tensor = Nx.tensor([0.0], type: :f32)
+
+    # With no true positives and no true negatives every prediction is wrong,
+    # so return -1 directly (the general formula can reduce to 0/0 here).
+    if Nx.all(
+         true_positives == zero_tensor and
+           true_negatives == zero_tensor
+       ) do
+      Nx.tensor([-1.0], type: :f32)
+    else
+      Nx.select(
+        mcc_denominator == zero_tensor,
+        zero_tensor,
+        mcc_numerator / mcc_denominator
+      )
+    end
+  end
 end
diff --git a/test/scholar/metrics/classification_test.exs b/test/scholar/metrics/classification_test.exs
index 7667e2df..975abed0 100644
--- a/test/scholar/metrics/classification_test.exs
+++ b/test/scholar/metrics/classification_test.exs
@@ -34,4 +34,42 @@ defmodule Scholar.Metrics.ClassificationTest do
      assert_all_close(fbeta_scores, Classification.precision(y_true, y_pred, num_classes: 2))
    end
  end
+
+  describe "mcc/2" do
+    test "returns 1 for perfect predictions" do
+      y_true = Nx.tensor([1, 0, 1, 0, 1])
+      y_pred = Nx.tensor([1, 0, 1, 0, 1])
+      assert Classification.mcc(y_true, y_pred) == Nx.tensor([1.0], type: :f32)
+    end
+
+    test "returns -1 for completely wrong predictions" do
+      y_true = Nx.tensor([1, 0, 1, 0, 1])
+      y_pred = Nx.tensor([0, 1, 0, 1, 0])
+      assert Classification.mcc(y_true, y_pred) == Nx.tensor([-1.0], type: :f32)
+    end
+
+    test "returns 0 when all predictions are positive" do
+      y_true = Nx.tensor([1, 0, 1, 0, 1])
+      y_pred = Nx.tensor([1, 1, 1, 1, 1])
+      assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
+    end
+
+    test "returns 0 when all predictions are negative" do
+      y_true = Nx.tensor([1, 0, 1, 0, 1])
+      y_pred = Nx.tensor([0, 0, 0, 0, 0])
+      assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
+    end
+
+    test "computes MCC for generic case" do
+      y_true = Nx.tensor([1, 0, 1, 0, 1])
+      y_pred = Nx.tensor([1, 0, 1, 1, 1])
+      assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.6123723983764648], type: :f32)
+    end
+
+    test "returns 0 when y_true and y_pred contain only the negative class" do
+      y_true = Nx.tensor([0, 0, 0, 0, 0])
+      y_pred = Nx.tensor([0, 0, 0, 0, 0])
+      assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
+    end
+  end
 end
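
Note on the formula: mcc/2 implements the standard definition of the coefficient in terms of the binary confusion matrix,

    \mathrm{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP + FP)(TP + FN)(TN + FP)(TN + FN)}}

which also accounts for the expected value in the "computes MCC for generic case" test: there TP = 3, TN = 1, FP = 1, FN = 0, so MCC = (3 * 1 - 1 * 0) / sqrt(4 * 3 * 2 * 1) = 3 / sqrt(24) ≈ 0.6123724, the asserted f32 value.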
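
For a quick manual check, a minimal IEx sketch (assuming a project that compiles this module with Nx available; the result mirrors the "generic case" test above, and the printed form assumes Nx's default tensor inspect):

    iex> y_true = Nx.tensor([1, 0, 1, 0, 1])
    iex> y_pred = Nx.tensor([1, 0, 1, 1, 1])
    iex> Scholar.Metrics.Classification.mcc(y_true, y_pred)
    #Nx.Tensor<
      f32[1]
      [0.6123723983764648]
    >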