From 162c530f2fae9a07e7f359b7d2c9b4e7d0e3610f Mon Sep 17 00:00:00 2001 From: Sean Gloumeau Date: Tue, 24 Oct 2023 23:12:10 +0200 Subject: [PATCH 1/8] Added log loss schema and skeleton functions --- lib/scholar/metrics/classification.ex | 40 +++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index 797d1cd6..e72cdb60 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -161,6 +161,25 @@ defmodule Scholar.Metrics.Classification do ] ] + log_loss_schema = + general_schema ++ [ + normalize: [ + type: :boolean, + default: true, + doc: """ + If `true`, return the mean loss over the samples. + Otherwise, return the sum of losses over the samples. + """ + ], + sample_weights: [ + type: {:custom, Scholar.Options, :weights, []}, + default: 1.0, + doc: """ + Sample weights of the observations. + """ + ] + ] + top_k_accuracy_score_schema = general_schema ++ [ @@ -203,6 +222,7 @@ defmodule Scholar.Metrics.Classification do ) @brier_score_loss_schema NimbleOptions.new!(brier_score_loss_schema) @accuracy_schema NimbleOptions.new!(accuracy_schema) + @log_loss_schema NimbleOptions.new!(log_loss_schema) @top_k_accuracy_score_schema NimbleOptions.new!(top_k_accuracy_score_schema) @zero_one_loss_schema NimbleOptions.new!(zero_one_loss_schema) @@ -1233,6 +1253,26 @@ defmodule Scholar.Metrics.Classification do 1 - Nx.sum(weights_matrix * cm) / Nx.sum(weights_matrix * expected) end + @doc """ + Computes the log loss, aka logistic loss or cross-entropy loss, of predictive + class probabilities given the true classes. + + ## Options + + #{NimbleOptions.docs(@log_loss_schema)} + """ + deftransform log_loss(y_true, y_prob, opts \\ []) do + log_loss_n( + y_true, + y_prob, + NimbleOptions.validate!(opts, @log_loss_schema) + ) + end + + defnp log_loss_n(y_true, y_prob, opts) do + y_true + end + @doc """ Top-k Accuracy classification score. From 9c1bcfab64f061833efb6c3699f570276df086d0 Mon Sep 17 00:00:00 2001 From: Sean Gloumeau Date: Tue, 24 Oct 2023 23:16:40 +0200 Subject: [PATCH 2/8] Added log loss computation assuming ideal inputs --- lib/scholar/metrics/classification.ex | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index e72cdb60..9ef67df8 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -13,6 +13,7 @@ defmodule Scholar.Metrics.Classification do import Nx.Defn, except: [assert_shape: 2, assert_shape_pattern: 2] import Scholar.Shared + import Scholar.Preprocessing alias Scholar.Integrate general_schema = [ @@ -1270,7 +1271,18 @@ defmodule Scholar.Metrics.Classification do end defnp log_loss_n(y_true, y_prob, opts) do - y_true + num_classes = opts[:num_classes] + y_true_onehot = + ordinal_encode(y_true, num_classes: num_classes) + |> one_hot_encode(num_classes: num_classes) + y_prob = Nx.clip(y_prob, 0, 1) + sample_loss = + Nx.multiply(y_true_onehot, y_prob) + |> Nx.sum(axes: [-1]) + |> Nx.log() + |> Nx.negate() + + Nx.mean(sample_loss) end @doc """ From 725149fc04124393b017a7cb829567f1dd40f13f Mon Sep 17 00:00:00 2001 From: Sean Gloumeau Date: Tue, 24 Oct 2023 23:19:16 +0200 Subject: [PATCH 3/8] Added argument checking --- lib/scholar/metrics/classification.ex | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index 9ef67df8..b7cb2058 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1271,7 +1271,21 @@ defmodule Scholar.Metrics.Classification do end defnp log_loss_n(y_true, y_prob, opts) do + assert_rank!(y_true, 1) + assert_rank!(y_prob, 2) + if Nx.axis_size(y_true, 0) != Nx.axis_size(y_prob, 0) do + raise ArgumentError, "y_true and y_prob must have the same size along axis 0" + end num_classes = opts[:num_classes] + if Nx.axis_size(y_prob, 1) != num_classes do + raise ArgumentError, "y_prob must have a size of num_classes along axis 1" + end + weights = validate_weights( + opts[:sample_weights], + Nx.axis_size(y_true, 0), + type: to_float_type(y_prob) + ) + y_true_onehot = ordinal_encode(y_true, num_classes: num_classes) |> one_hot_encode(num_classes: num_classes) @@ -1282,7 +1296,12 @@ defmodule Scholar.Metrics.Classification do |> Nx.log() |> Nx.negate() - Nx.mean(sample_loss) + if opts[:normalize] do + Nx.weighted_mean(sample_loss, weights) + else + Nx.multiply(sample_loss, weights) + |> Nx.sum() + end end @doc """ From 5c633685d7439a7e06bb8e14315975ed6c76729f Mon Sep 17 00:00:00 2001 From: Sean Gloumeau Date: Tue, 24 Oct 2023 23:57:58 +0200 Subject: [PATCH 4/8] Added description, client expectations, and examples to function doc --- lib/scholar/metrics/classification.ex | 34 +++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index b7cb2058..f83440e1 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1255,12 +1255,42 @@ defmodule Scholar.Metrics.Classification do end @doc """ - Computes the log loss, aka logistic loss or cross-entropy loss, of predictive - class probabilities given the true classes. + Computes the log loss, aka logistic loss or cross-entropy loss. + + The log-loss is a measure of how well a forecaster performs, with smaller + values being better. For each sample, a forecaster outputs a probability for + each class, from which the log loss is computed by averaging the negative log + of the probability forecasted for the true class over a number of samples. + + `y_true` should contain `num_classes` unique values, and the sum of `y_pred` + along axis 1 should be 1 to respect the law of total probability. ## Options #{NimbleOptions.docs(@log_loss_schema)} + + ## Examples + + iex> y_true = Nx.tensor([0, 0, 1, 1]) + iex> y_pred = Nx.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.01, 0.99]]) + iex> Scholar.Metrics.Classification.log_loss(y_true, y_pred, num_classes: 2) + #Nx.Tensor< + f32 + 0.17380733788013458 + > + + iex> Scholar.Metrics.Classification.log_loss(y_true, y_pred, num_classes: 2, normalize: false) + #Nx.Tensor< + f32 + 0.6952293515205383 + > + + iex> weights = Nx.tensor([0.7, 2.3, 1.3, 0.34]) + iex(361)> Scholar.Metrics.Classification.log_loss(y_true, y_pred, num_classes: 2, sample_weights: weights) + #Nx.Tensor< + f32 + 0.22717177867889404 + > """ deftransform log_loss(y_true, y_prob, opts \\ []) do log_loss_n( From 629b948c87c18ddb462170ca8aac2dc9ce1aec95 Mon Sep 17 00:00:00 2001 From: Sean Gloumeau Date: Tue, 24 Oct 2023 23:58:44 +0200 Subject: [PATCH 5/8] Reformatted --- lib/scholar/metrics/classification.ex | 50 ++++++++++++++++----------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index f83440e1..a8f181c5 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -163,23 +163,24 @@ defmodule Scholar.Metrics.Classification do ] log_loss_schema = - general_schema ++ [ - normalize: [ - type: :boolean, - default: true, - doc: """ - If `true`, return the mean loss over the samples. - Otherwise, return the sum of losses over the samples. - """ - ], - sample_weights: [ - type: {:custom, Scholar.Options, :weights, []}, - default: 1.0, - doc: """ - Sample weights of the observations. - """ + general_schema ++ + [ + normalize: [ + type: :boolean, + default: true, + doc: """ + If `true`, return the mean loss over the samples. + Otherwise, return the sum of losses over the samples. + """ + ], + sample_weights: [ + type: {:custom, Scholar.Options, :weights, []}, + default: 1.0, + doc: """ + Sample weights of the observations. + """ + ] ] - ] top_k_accuracy_score_schema = general_schema ++ @@ -1303,23 +1304,30 @@ defmodule Scholar.Metrics.Classification do defnp log_loss_n(y_true, y_prob, opts) do assert_rank!(y_true, 1) assert_rank!(y_prob, 2) + if Nx.axis_size(y_true, 0) != Nx.axis_size(y_prob, 0) do raise ArgumentError, "y_true and y_prob must have the same size along axis 0" end + num_classes = opts[:num_classes] + if Nx.axis_size(y_prob, 1) != num_classes do raise ArgumentError, "y_prob must have a size of num_classes along axis 1" end - weights = validate_weights( - opts[:sample_weights], - Nx.axis_size(y_true, 0), - type: to_float_type(y_prob) - ) + + weights = + validate_weights( + opts[:sample_weights], + Nx.axis_size(y_true, 0), + type: to_float_type(y_prob) + ) y_true_onehot = ordinal_encode(y_true, num_classes: num_classes) |> one_hot_encode(num_classes: num_classes) + y_prob = Nx.clip(y_prob, 0, 1) + sample_loss = Nx.multiply(y_true_onehot, y_prob) |> Nx.sum(axes: [-1]) From 5497ac84b5c67a574e784b1bea2eda59f1d009b5 Mon Sep 17 00:00:00 2001 From: Sean Gloumeau Date: Thu, 26 Oct 2023 11:20:01 +0200 Subject: [PATCH 6/8] Fixed typos --- lib/scholar/metrics/classification.ex | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index a8f181c5..26258848 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1263,7 +1263,7 @@ defmodule Scholar.Metrics.Classification do each class, from which the log loss is computed by averaging the negative log of the probability forecasted for the true class over a number of samples. - `y_true` should contain `num_classes` unique values, and the sum of `y_pred` + `y_true` should contain `num_classes` unique values, and the sum of `y_prob` along axis 1 should be 1 to respect the law of total probability. ## Options @@ -1273,21 +1273,21 @@ defmodule Scholar.Metrics.Classification do ## Examples iex> y_true = Nx.tensor([0, 0, 1, 1]) - iex> y_pred = Nx.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.01, 0.99]]) - iex> Scholar.Metrics.Classification.log_loss(y_true, y_pred, num_classes: 2) + iex> y_prob = Nx.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.01, 0.99]]) + iex> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2) #Nx.Tensor< f32 0.17380733788013458 > - iex> Scholar.Metrics.Classification.log_loss(y_true, y_pred, num_classes: 2, normalize: false) + iex> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2, normalize: false) #Nx.Tensor< f32 0.6952293515205383 > iex> weights = Nx.tensor([0.7, 2.3, 1.3, 0.34]) - iex(361)> Scholar.Metrics.Classification.log_loss(y_true, y_pred, num_classes: 2, sample_weights: weights) + iex(361)> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2, sample_weights: weights) #Nx.Tensor< f32 0.22717177867889404 From 69f5f9f1a34165373ece28846d7d6712c2a2c7a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Thu, 26 Oct 2023 13:20:46 +0200 Subject: [PATCH 7/8] Update lib/scholar/metrics/classification.ex --- lib/scholar/metrics/classification.ex | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index 26258848..d3c9391e 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1279,7 +1279,6 @@ defmodule Scholar.Metrics.Classification do f32 0.17380733788013458 > - iex> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2, normalize: false) #Nx.Tensor< f32 From 025bb853b78fcf3252f70aaa4c05a143f601794e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Thu, 26 Oct 2023 13:21:02 +0200 Subject: [PATCH 8/8] Update lib/scholar/metrics/classification.ex --- lib/scholar/metrics/classification.ex | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/scholar/metrics/classification.ex b/lib/scholar/metrics/classification.ex index d3c9391e..64291aa2 100644 --- a/lib/scholar/metrics/classification.ex +++ b/lib/scholar/metrics/classification.ex @@ -1284,7 +1284,6 @@ defmodule Scholar.Metrics.Classification do f32 0.6952293515205383 > - iex> weights = Nx.tensor([0.7, 2.3, 1.3, 0.34]) iex(361)> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2, sample_weights: weights) #Nx.Tensor<