Skip to content

Commit

Permalink
Matthews Correlation Coefficient (#203)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulsullivanjr authored Nov 5, 2023
1 parent 6f2aa17 commit 0f95b58
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 0 deletions.
38 changes: 38 additions & 0 deletions lib/scholar/metrics/classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1448,4 +1448,42 @@ defmodule Scholar.Metrics.Classification do
assert_rank!(y_true, 1)
assert_same_shape!(y_true, y_pred)
end

@doc """
Computes the Matthews Correlation Coefficient (MCC) of a binary classification.

MCC condenses the confusion matrix into a single value between -1 and 1, where
1 represents a perfect prediction, 0 represents a prediction no better than
random, and -1 indicates total disagreement between prediction and observation:

    MCC = (TP * TN - FP * FN) / sqrt((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN))

When the denominator is zero (one of the four marginal sums of the confusion
matrix is empty, for example because only one class is present or predicted)
the coefficient is undefined and `0.0` is returned by convention.

`y_true` holds the observed labels and `y_pred` the predicted labels. The
result is an `:f32` tensor built from a `{1}`-shaped zero tensor, so it
broadcasts to at least shape `{1}`.
"""
defn mcc(y_true, y_pred) do
  # The counting helpers are defined elsewhere in this module. Their results are
  # cast to floats before any arithmetic so that the subtraction in the numerator
  # cannot wrap around if the counts are unsigned integers, and the four-factor
  # product in the denominator cannot overflow an integer type on large inputs.
  true_positives = Nx.as_type(binary_true_positives(y_true, y_pred), :f32)
  true_negatives = Nx.as_type(binary_true_negatives(y_true, y_pred), :f32)
  false_positives = Nx.as_type(binary_false_positives(y_true, y_pred), :f32)
  false_negatives = Nx.as_type(binary_false_negatives(y_true, y_pred), :f32)

  mcc_numerator = true_positives * true_negatives - false_positives * false_negatives

  mcc_denominator =
    Nx.sqrt(
      (true_positives + false_positives) *
        (true_positives + false_negatives) *
        (true_negatives + false_positives) *
        (true_negatives + false_negatives)
    )

  zero_tensor = Nx.tensor([0.0], type: :f32)

  # No special case is needed for "everything is wrong" (TP == TN == 0 with both
  # FP and FN non-zero): the formula reduces to -FP * FN / sqrt((FP * FN)^2) = -1.
  # If TP == TN == 0 and FP or FN is also zero, the denominator is zero and the
  # zero-denominator convention below applies instead of reporting -1.
  # The quotient may be NaN/Inf where the denominator is zero; `Nx.select`
  # discards it in favour of `zero_tensor` for exactly those positions.
  Nx.select(
    mcc_denominator == zero_tensor,
    zero_tensor,
    mcc_numerator / mcc_denominator
  )
end
end
38 changes: 38 additions & 0 deletions test/scholar/metrics/classification_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,42 @@ defmodule Scholar.Metrics.ClassificationTest do
assert_all_close(fbeta_scores, Classification.precision(y_true, y_pred, num_classes: 2))
end
end

# Tests for Classification.mcc/2. The expected values are one-element :f32
# tensors. Exact equality is used where the expected value is exactly
# representable (1.0, -1.0, 0.0), which also pins the result's shape and type.
describe "mcc/2" do
  test "returns 1 for perfect predictions" do
    y_true = Nx.tensor([1, 0, 1, 0, 1])
    y_pred = Nx.tensor([1, 0, 1, 0, 1])
    assert Classification.mcc(y_true, y_pred) == Nx.tensor([1.0], type: :f32)
  end

  test "returns -1 for completely wrong predictions" do
    y_true = Nx.tensor([1, 0, 1, 0, 1])
    y_pred = Nx.tensor([0, 1, 0, 1, 0])
    assert Classification.mcc(y_true, y_pred) == Nx.tensor([-1.0], type: :f32)
  end

  test "returns 0 when all predictions are positive" do
    y_true = Nx.tensor([1, 0, 1, 0, 1])
    y_pred = Nx.tensor([1, 1, 1, 1, 1])
    assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
  end

  test "returns 0 when all predictions are negative" do
    y_true = Nx.tensor([1, 0, 1, 0, 1])
    y_pred = Nx.tensor([0, 0, 0, 0, 0])
    assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
  end

  test "computes MCC for generic case" do
    y_true = Nx.tensor([1, 0, 1, 0, 1])
    y_pred = Nx.tensor([1, 0, 1, 1, 1])

    # TP = 3, TN = 1, FP = 1, FN = 0, so MCC = 3 / sqrt(24). The value is
    # irrational, so compare with a tolerance rather than bit-for-bit.
    assert_all_close(
      Classification.mcc(y_true, y_pred),
      Nx.tensor([0.6123723983764648], type: :f32)
    )
  end

  # Every sample is a true negative here (TN = 5, TP = FP = FN = 0), so the
  # denominator of the MCC formula is zero.
  test "returns 0 when only true negatives are present" do
    y_true = Nx.tensor([0, 0, 0, 0, 0])
    y_pred = Nx.tensor([0, 0, 0, 0, 0])
    assert Classification.mcc(y_true, y_pred) == Nx.tensor([0.0], type: :f32)
  end
end
end

0 comments on commit 0f95b58

Please sign in to comment.