Skip to content

Commit

Permalink
Added log loss classification metric (elixir-nx#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
svarqq authored Oct 26, 2023
1 parent 33d3bd6 commit cf8bbe4
Showing 1 changed file with 107 additions and 0 deletions.
107 changes: 107 additions & 0 deletions lib/scholar/metrics/classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ defmodule Scholar.Metrics.Classification do

import Nx.Defn, except: [assert_shape: 2, assert_shape_pattern: 2]
import Scholar.Shared
import Scholar.Preprocessing
alias Scholar.Integrate

general_schema = [
Expand Down Expand Up @@ -161,6 +162,26 @@ defmodule Scholar.Metrics.Classification do
]
]

log_loss_schema =
general_schema ++
[
normalize: [
type: :boolean,
default: true,
doc: """
If `true`, return the mean loss over the samples.
Otherwise, return the sum of losses over the samples.
"""
],
sample_weights: [
type: {:custom, Scholar.Options, :weights, []},
default: 1.0,
doc: """
Sample weights of the observations.
"""
]
]

top_k_accuracy_score_schema =
general_schema ++
[
Expand Down Expand Up @@ -203,6 +224,7 @@ defmodule Scholar.Metrics.Classification do
)
@brier_score_loss_schema NimbleOptions.new!(brier_score_loss_schema)
@accuracy_schema NimbleOptions.new!(accuracy_schema)
@log_loss_schema NimbleOptions.new!(log_loss_schema)
@top_k_accuracy_score_schema NimbleOptions.new!(top_k_accuracy_score_schema)
@zero_one_loss_schema NimbleOptions.new!(zero_one_loss_schema)

Expand Down Expand Up @@ -1233,6 +1255,91 @@ defmodule Scholar.Metrics.Classification do
1 - Nx.sum(weights_matrix * cm) / Nx.sum(weights_matrix * expected)
end

@doc """
Computes the log loss, aka logistic loss or cross-entropy loss.
The log-loss is a measure of how well a forecaster performs, with smaller
values being better. For each sample, a forecaster outputs a probability for
each class, from which the log loss is computed by averaging the negative log
of the probability forecasted for the true class over a number of samples.
`y_true` should contain `num_classes` unique values, and the sum of `y_prob`
along axis 1 should be 1 to respect the law of total probability.
## Options
#{NimbleOptions.docs(@log_loss_schema)}
## Examples
iex> y_true = Nx.tensor([0, 0, 1, 1])
iex> y_prob = Nx.tensor([[0.9, 0.1], [0.8, 0.2], [0.3, 0.7], [0.01, 0.99]])
iex> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2)
#Nx.Tensor<
f32
0.17380733788013458
>
iex> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2, normalize: false)
#Nx.Tensor<
f32
0.6952293515205383
>
iex> weights = Nx.tensor([0.7, 2.3, 1.3, 0.34])
iex(361)> Scholar.Metrics.Classification.log_loss(y_true, y_prob, num_classes: 2, sample_weights: weights)
#Nx.Tensor<
f32
0.22717177867889404
>
"""
deftransform log_loss(y_true, y_prob, opts \\ []) do
log_loss_n(
y_true,
y_prob,
NimbleOptions.validate!(opts, @log_loss_schema)
)
end

defnp log_loss_n(y_true, y_prob, opts) do
assert_rank!(y_true, 1)
assert_rank!(y_prob, 2)

if Nx.axis_size(y_true, 0) != Nx.axis_size(y_prob, 0) do
raise ArgumentError, "y_true and y_prob must have the same size along axis 0"
end

num_classes = opts[:num_classes]

if Nx.axis_size(y_prob, 1) != num_classes do
raise ArgumentError, "y_prob must have a size of num_classes along axis 1"
end

weights =
validate_weights(
opts[:sample_weights],
Nx.axis_size(y_true, 0),
type: to_float_type(y_prob)
)

y_true_onehot =
ordinal_encode(y_true, num_classes: num_classes)
|> one_hot_encode(num_classes: num_classes)

y_prob = Nx.clip(y_prob, 0, 1)

sample_loss =
Nx.multiply(y_true_onehot, y_prob)
|> Nx.sum(axes: [-1])
|> Nx.log()
|> Nx.negate()

if opts[:normalize] do
Nx.weighted_mean(sample_loss, weights)
else
Nx.multiply(sample_loss, weights)
|> Nx.sum()
end
end

@doc """
Top-k Accuracy classification score.
Expand Down

0 comments on commit cf8bbe4

Please sign in to comment.