diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex
index d9bf2f7b..373c1b45 100644
--- a/lib/scholar/metrics/classification.ex
+++ b/lib/scholar/metrics/classification.ex
@@ -48,6 +48,37 @@ defmodule Scholar.Metrics.Classification do
 
   fbeta_score_schema = f1_score_schema
 
+  precision_recall_fscore_support_schema =
+    general_schema ++
+      [
+        average: [
+          type: {:in, [:micro, :macro, :weighted, :none]},
+          default: :none,
+          doc: """
+          This determines the type of averaging performed on the data.
+
+          * `:macro` - Calculate metrics for each label, and find their unweighted mean.
+          This does not take label imbalance into account.
+
+          * `:weighted` - Calculate metrics for each label, and find their average weighted by
+          support (the number of true instances for each label).
+
+          * `:micro` - Calculate metrics globally by counting the total true positives,
+          false negatives and false positives.
+
+          * `:none` - The scores for each class are returned.
+          """
+        ],
+        beta: [
+          type: {:custom, Scholar.Options, :beta, []},
+          default: 1,
+          doc: """
+          Determines the weight of recall in the combined score.
+          Values of `beta` > 1 give more weight to recall, while `beta` < 1 favors precision.
+          """
+        ]
+      ]
+
   confusion_matrix_schema =
     general_schema ++
       [
@@ -167,6 +198,9 @@ defmodule Scholar.Metrics.Classification do
   @cohen_kappa_schema NimbleOptions.new!(cohen_kappa_schema)
   @fbeta_score_schema NimbleOptions.new!(fbeta_score_schema)
   @f1_score_schema NimbleOptions.new!(f1_score_schema)
+  @precision_recall_fscore_support_schema NimbleOptions.new!(
+                                            precision_recall_fscore_support_schema
+                                          )
   @brier_score_loss_schema NimbleOptions.new!(brier_score_loss_schema)
   @accuracy_schema NimbleOptions.new!(accuracy_schema)
   @top_k_accuracy_score_schema NimbleOptions.new!(top_k_accuracy_score_schema)
@@ -649,12 +683,8 @@ defmodule Scholar.Metrics.Classification do
   end
 
   defnp fbeta_score_n(y_true, y_pred, beta, opts) do
-    check_shape(y_pred, y_true)
-    num_classes = check_num_classes(opts[:num_classes])
-    average = opts[:average]
-
-    {_precision, _recall, per_class_fscore} =
-      precision_recall_fscore_n(y_true, y_pred, beta, num_classes, average)
+    {_precision, _recall, per_class_fscore, _support} =
+      precision_recall_fscore_support_n(y_true, y_pred, beta, opts)
 
     per_class_fscore
   end
@@ -677,7 +707,66 @@ defmodule Scholar.Metrics.Classification do
     end
   end
 
-  defnp precision_recall_fscore_n(y_true, y_pred, beta, num_classes, average) do
+  @doc """
+  Calculates precision, recall, F-score, and support for each
+  class. It also accepts a `beta` argument, which weights
+  recall more than precision by its value.
+
+  ## Options
+
+  #{NimbleOptions.docs(@precision_recall_fscore_support_schema)}
+
+  ## Examples
+
+      iex> y_true = Nx.tensor([0, 1, 1, 1, 1, 0, 2, 1, 0, 1], type: :u32)
+      iex> y_pred = Nx.tensor([0, 2, 1, 1, 2, 2, 2, 0, 0, 1], type: :u32)
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3)
+      {Nx.f32([0.6666666865348816, 1.0, 0.25]),
+       Nx.f32([0.6666666865348816, 0.5, 1.0]),
+       Nx.f32([0.6666666865348816, 0.6666666865348816, 0.4000000059604645]),
+       Nx.u64([3, 6, 1])}
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3, average: :macro)
+      {Nx.f32([0.6666666865348816, 1.0, 0.25]),
+       Nx.f32([0.6666666865348816, 0.5, 1.0]),
+       Nx.f32(0.5777778029441833),
+       Nx.Constants.nan()}
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3, average: :weighted)
+      {Nx.f32([0.6666666865348816, 1.0, 0.25]),
+       Nx.f32([0.6666666865348816, 0.5, 1.0]),
+       Nx.f32(0.6399999856948853),
+       Nx.Constants.nan()}
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, num_classes: 3, average: :micro)
+      {Nx.f32(0.6000000238418579),
+       Nx.f32(0.6000000238418579),
+       Nx.f32(0.6000000238418579),
+       Nx.Constants.nan()}
+
+      iex> y_true = Nx.tensor([1, 0, 1, 0], type: :u32)
+      iex> y_pred = Nx.tensor([0, 1, 0, 1], type: :u32)
+      iex> opts = [beta: 2, num_classes: 2, average: :none]
+      iex> Scholar.Metrics.Classification.precision_recall_fscore_support(y_true, y_pred, opts)
+      {Nx.f32([0.0, 0.0]),
+       Nx.f32([0.0, 0.0]),
+       Nx.f32([0.0, 0.0]),
+       Nx.u64([2, 2])}
+  """
+  deftransform precision_recall_fscore_support(y_true, y_pred, opts) do
+    opts = NimbleOptions.validate!(opts, @precision_recall_fscore_support_schema)
+    {beta, opts} = Keyword.pop(opts, :beta)
+
+    precision_recall_fscore_support_n(
+      y_true,
+      y_pred,
+      beta,
+      opts
+    )
+  end
+
+  defnp precision_recall_fscore_support_n(y_true, y_pred, beta, opts) do
+    check_shape(y_pred, y_true)
+    num_classes = check_num_classes(opts[:num_classes])
+    average = opts[:average]
+
     confusion_matrix = confusion_matrix(y_true, y_pred, num_classes: num_classes)
 
     {true_positive, false_positive, false_negative} = fbeta_score_v(confusion_matrix, average)
@@ -700,13 +789,15 @@
 
     case average do
       :none ->
-        {precision, recall, per_class_fscore}
+        support = (y_true == Nx.iota({num_classes, 1})) |> Nx.sum(axes: [1])
+
+        {precision, recall, per_class_fscore, support}
 
       :micro ->
-        {precision, recall, per_class_fscore}
+        {precision, recall, per_class_fscore, Nx.Constants.nan()}
 
       :macro ->
-        {precision, recall, Nx.mean(per_class_fscore)}
+        {precision, recall, Nx.mean(per_class_fscore), Nx.Constants.nan()}
 
       :weighted ->
         support = (y_true == Nx.iota({num_classes, 1})) |> Nx.sum(axes: [1])
@@ -716,7 +807,7 @@
           |> safe_division(Nx.sum(support))
           |> Nx.sum()
 
-        {precision, recall, per_class_fscore}
+        {precision, recall, per_class_fscore, Nx.Constants.nan()}
     end
   end
 
diff --git a/lib/scholar/options.ex b/lib/scholar/options.ex
index a779d572..e0173ad1 100644
--- a/lib/scholar/options.ex
+++ b/lib/scholar/options.ex
@@ -100,4 +100,12 @@ defmodule Scholar.Options do
       {:error,
        "expected metric to be a :cosine or tuple {:minkowski, p} where p is a positive number or :infinity, got: #{inspect(metric)}"}
   end
+
+  def beta(beta) do
+    if (is_number(beta) and beta >= 0) or (Nx.is_tensor(beta) and Nx.rank(beta) == 0) do
+      {:ok, beta}
+    else
+      {:error, "expected 'beta' to be in the range [0, inf], got: #{inspect(beta)}"}
+    end
+  end
 end
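
Reviewer note: as a sanity check on the `beta` and `average` semantics documented above, the first doctest's per-class F-scores and its `:macro`/`:weighted` aggregates can be reproduced by hand from the standard F-beta formula, F_beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall). The snippet below is a standalone sketch against plain `Nx`, not code from the patch:

```elixir
# Per-class precision, recall, and support taken from the first doctest above.
precision = Nx.tensor([2 / 3, 1.0, 0.25])
recall = Nx.tensor([2 / 3, 0.5, 1.0])
support = Nx.tensor([3, 6, 1])
beta = 1

# F-beta combines precision and recall; beta > 1 weights recall more
# heavily, beta < 1 favors precision, and beta == 1 is the harmonic mean.
fscore =
  Nx.divide(
    Nx.multiply(1 + beta * beta, Nx.multiply(precision, recall)),
    Nx.add(Nx.multiply(beta * beta, precision), recall)
  )
# => [0.6666..., 0.6666..., 0.4], the per-class F-scores in the doctest

# :macro is the unweighted mean; :weighted scales each class by its support.
Nx.mean(fscore)
# => 0.5777...
Nx.divide(Nx.sum(Nx.multiply(support, fscore)), Nx.sum(support))
# => 0.64
```

The `:micro` result (0.6 in all three positions of the tuple) is plain accuracy here: with single-label multiclass input, micro-averaged precision, recall, and F-score all reduce to the fraction of correct predictions.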
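On the `lib/scholar/options.ex` side, the new `beta/1` validator is invoked through NimbleOptions' `{:custom, module, function, args}` type. Below is a hypothetical session showing how it accepts and rejects values; the inline schema is illustrative, not the module-level schema the patch actually uses:

```elixir
# Minimal schema wired to the same custom validator as the patch.
schema =
  NimbleOptions.new!(beta: [type: {:custom, Scholar.Options, :beta, []}, default: 1])

NimbleOptions.validate!([beta: 0.5], schema)
# => [beta: 0.5]

# Per the validator's condition, any rank-0 tensor also passes,
# even though its sign is not checked.
NimbleOptions.validate!([beta: Nx.tensor(2)], schema)

NimbleOptions.validate!([beta: -1], schema)
# ** (NimbleOptions.ValidationError) invalid value for :beta option:
#    expected 'beta' to be in the range [0, inf], got: -1
```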