Add precision_recall_fscore_support function (#186)

0urobor0s authored Oct 17, 2023
1 parent cdccb19 commit b36df2f
Showing 2 changed files with 110 additions and 11 deletions.
lib/scholar/metrics/classification.ex (102 additions, 11 deletions)
@@ -48,6 +48,37 @@ defmodule Scholar.Metrics.Classification do
 
   fbeta_score_schema = f1_score_schema
 
+  precision_recall_fscore_support_schema =
+    general_schema ++
+      [
+        average: [
+          type: {:in, [:micro, :macro, :weighted, :none]},
+          default: :none,
+          doc: """
+          This determines the type of averaging performed on the data.
+
+          * `:macro` - Calculate metrics for each label, and find their unweighted mean.
+            This does not take label imbalance into account.
+
+          * `:weighted` - Calculate metrics for each label, and find their average weighted by
+            support (the number of true instances for each label).
+
+          * `:micro` - Calculate metrics globally by counting the total true positives,
+            false negatives and false positives.
+
+          * `:none` - The F-score values for each class are returned.
+          """
+        ],
+        beta: [
+          type: {:custom, Scholar.Options, :beta, []},
+          default: 1,
+          doc: """
+          Determines the weight of recall in the combined score.
+          Values of `beta` greater than 1 give more weight to recall, while values less than 1 favor precision.
+          """
+        ]
+      ]
+
   confusion_matrix_schema =
     general_schema ++
       [
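Note (editor's addition; the standard definition, not something stated in the diff): the combined score that the `beta` option controls is the F-beta measure,

    F_\beta = (1 + \beta^2) \cdot \frac{\mathrm{precision} \cdot \mathrm{recall}}{\beta^2 \cdot \mathrm{precision} + \mathrm{recall}}

which reduces to the harmonic mean of precision and recall (F1) at beta = 1.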
@@ -167,6 +198,9 @@ defmodule Scholar.Metrics.Classification do
   @cohen_kappa_schema NimbleOptions.new!(cohen_kappa_schema)
   @fbeta_score_schema NimbleOptions.new!(fbeta_score_schema)
   @f1_score_schema NimbleOptions.new!(f1_score_schema)
+  @precision_recall_fscore_support_schema NimbleOptions.new!(
+                                            precision_recall_fscore_support_schema
+                                          )
   @brier_score_loss_schema NimbleOptions.new!(brier_score_loss_schema)
   @accuracy_schema NimbleOptions.new!(accuracy_schema)
   @top_k_accuracy_score_schema NimbleOptions.new!(top_k_accuracy_score_schema)
@@ -649,12 +683,8 @@ defmodule Scholar.Metrics.Classification do
   end
 
   defnp fbeta_score_n(y_true, y_pred, beta, opts) do
-    check_shape(y_pred, y_true)
-    num_classes = check_num_classes(opts[:num_classes])
-    average = opts[:average]
-
-    {_precision, _recall, per_class_fscore} =
-      precision_recall_fscore_n(y_true, y_pred, beta, num_classes, average)
+    {_precision, _recall, per_class_fscore, _support} =
+      precision_recall_fscore_support_n(y_true, y_pred, beta, opts)
 
     per_class_fscore
  end
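Note (editor's sketch, not from the diff): `fbeta_score_n` now delegates to the shared `precision_recall_fscore_support_n` core and discards everything but the per-class F-score. A quick sanity check, assuming the public `fbeta_score/4` keeps `beta` as its third positional argument, mirroring `fbeta_score_n`:

    # Not part of the diff: with beta = 1 the shared core reduces F-beta to F1,
    # so fbeta_score and f1_score should agree on the same inputs.
    y_true = Nx.tensor([0, 1, 1, 1, 1, 0, 2, 1, 0, 1], type: :u32)
    y_pred = Nx.tensor([0, 2, 1, 1, 2, 2, 2, 0, 0, 1], type: :u32)
    Scholar.Metrics.Classification.fbeta_score(y_true, y_pred, 1, num_classes: 3)
    Scholar.Metrics.Classification.f1_score(y_true, y_pred, num_classes: 3)
    # both should yield the per-class F1 values from the doctest below:
    # [0.6666..., 0.6666..., 0.4]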
@@ -677,7 +707,66 @@ defmodule Scholar.Metrics.Classification do
     end
   end
 
-  defnp precision_recall_fscore_n(y_true, y_pred, beta, num_classes, average) do
+  @doc """
+  Calculates precision, recall, F-score and support for each
+  class. It also accepts a `beta` argument, which weights recall
+  against precision in the combined score: `beta` > 1 favors
+  recall, while `beta` < 1 favors precision.
+
+  ## Options
+
+  #{NimbleOptions.docs(@precision_recall_fscore_support_schema)}
+
+  ## Examples
+
+      iex> y_true = Nx.tensor([0, 1, 1, 1, 1, 0, 2, 1, 0, 1], type: :u32)
+      iex> y_pred = Nx.tensor([0, 2, 1, 1, 2, 2, 2, 0, 0, 1], type: :u32)
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3)
+      {Nx.f32([0.6666666865348816, 1.0, 0.25]),
+       Nx.f32([0.6666666865348816, 0.5, 1.0]),
+       Nx.f32([0.6666666865348816, 0.6666666865348816, 0.4000000059604645]),
+       Nx.u64([3, 6, 1])}
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3, average: :macro)
+      {Nx.f32([0.6666666865348816, 1.0, 0.25]),
+       Nx.f32([0.6666666865348816, 0.5, 1.0]),
+       Nx.f32(0.5777778029441833),
+       Nx.Constants.nan()}
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3, average: :weighted)
+      {Nx.f32([0.6666666865348816, 1.0, 0.25]),
+       Nx.f32([0.6666666865348816, 0.5, 1.0]),
+       Nx.f32(0.6399999856948853),
+       Nx.Constants.nan()}
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3, average: :micro)
+      {Nx.f32(0.6000000238418579),
+       Nx.f32(0.6000000238418579),
+       Nx.f32(0.6000000238418579),
+       Nx.Constants.nan()}
+      iex> y_true = Nx.tensor([1, 0, 1, 0], type: :u32)
+      iex> y_pred = Nx.tensor([0, 1, 0, 1], type: :u32)
+      iex> opts = [beta: 2, num_classes: 2, average: :none]
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, opts)
+      {Nx.f32([0.0, 0.0]),
+       Nx.f32([0.0, 0.0]),
+       Nx.f32([0.0, 0.0]),
+       Nx.u64([2, 2])}
+  """
+  deftransform precision_recall_fscore_support(y_true, y_pred, opts) do
+    opts = NimbleOptions.validate!(opts, @precision_recall_fscore_support_schema)
+    {beta, opts} = Keyword.pop(opts, :beta)
+
+    precision_recall_fscore_support_n(
+      y_true,
+      y_pred,
+      beta,
+      opts
+    )
+  end
+
+  defnp precision_recall_fscore_support_n(y_true, y_pred, beta, opts) do
+    check_shape(y_pred, y_true)
+    num_classes = check_num_classes(opts[:num_classes])
+    average = opts[:average]
+
     confusion_matrix = confusion_matrix(y_true, y_pred, num_classes: num_classes)
     {true_positive, false_positive, false_negative} = fbeta_score_v(confusion_matrix, average)
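Note (editor's sketch, not from the diff): the per-class counts come straight from the confusion matrix — true positives on the diagonal, predicted totals down the columns, actual totals along the rows — so the per-class precision and recall from the doctest above can be checked by hand on plain Nx tensors:

    # Confusion matrix for the doctest's y_true/y_pred (rows: true, cols: predicted).
    cm = Nx.tensor([[2, 0, 1], [1, 3, 2], [0, 0, 1]])
    tp = Nx.take_diagonal(cm)
    precision = Nx.divide(tp, Nx.sum(cm, axes: [0]))  # [0.667, 1.0, 0.25]
    recall = Nx.divide(tp, Nx.sum(cm, axes: [1]))     # [0.667, 0.5, 1.0]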

@@ -700,13 +789,15 @@ defmodule Scholar.Metrics.Classification do
 
     case average do
       :none ->
-        {precision, recall, per_class_fscore}
+        support = (y_true == Nx.iota({num_classes, 1})) |> Nx.sum(axes: [1])
+
+        {precision, recall, per_class_fscore, support}
 
       :micro ->
-        {precision, recall, per_class_fscore}
+        {precision, recall, per_class_fscore, Nx.Constants.nan()}
 
       :macro ->
-        {precision, recall, Nx.mean(per_class_fscore)}
+        {precision, recall, Nx.mean(per_class_fscore), Nx.Constants.nan()}
 
       :weighted ->
         support = (y_true == Nx.iota({num_classes, 1})) |> Nx.sum(axes: [1])
@@ -716,7 +807,7 @@ defmodule Scholar.Metrics.Classification do
          |> safe_division(Nx.sum(support))
          |> Nx.sum()
 
-        {precision, recall, per_class_fscore}
+        {precision, recall, per_class_fscore, Nx.Constants.nan()}
     end
   end
 
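Note (editor's sketch, not from the diff): the `support` expression relies on Nx broadcasting — comparing `y_true` (shape `{n}`) against a `{num_classes, 1}` iota yields a `{num_classes, n}` one-hot mask whose row sums are the per-class counts. Outside `defn`, the same trick reads:

    y_true = Nx.tensor([0, 1, 1, 1, 1, 0, 2, 1, 0, 1], type: :u32)
    mask = Nx.equal(y_true, Nx.iota({3, 1}))  # shape {3, 10}: row i flags where y_true == i
    Nx.sum(mask, axes: [1])                   # per-class support: [3, 6, 1]

(Inside `defn`, `==` compiles to `Nx.equal/2`, which is why the diff can use the operator directly.)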
lib/scholar/options.ex (8 additions, 0 deletions)
@@ -100,4 +100,12 @@ defmodule Scholar.Options do
       {:error,
        "expected metric to be a :cosine or tuple {:minkowski, p} where p is a positive number or :infinity, got: #{inspect(metric)}"}
   end
+
+  def beta(beta) do
+    if (is_number(beta) and beta >= 0) or (Nx.is_tensor(beta) and Nx.rank(beta) == 0) do
+      {:ok, beta}
+    else
+      {:error, "expected beta to be in the range [0, inf]"}
+    end
+  end
 end
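Note (editor's sketch, not from the diff): NimbleOptions calls this validator through the `{:custom, Scholar.Options, :beta, []}` type declared in the schema above, expecting `{:ok, value}` or `{:error, message}` back. A minimal stand-alone check:

    schema = NimbleOptions.new!(beta: [type: {:custom, Scholar.Options, :beta, []}, default: 1])
    NimbleOptions.validate!([beta: 2], schema)
    #=> [beta: 2]
    NimbleOptions.validate!([beta: -1], schema)
    # raises NimbleOptions.ValidationError; the message should read roughly:
    # invalid value for :beta option: expected beta to be in the range [0, inf]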
