Skip to content

Commit

Permalink
Add d2_tweedie_score (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
0urobor0s authored Oct 27, 2023
1 parent cbccaa3 commit 3793cd5
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 0 deletions.
40 changes: 40 additions & 0 deletions lib/scholar/metrics/regression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,46 @@ defmodule Scholar.Metrics.Regression do
end
end

@doc """
$D^2$ regression score function, fraction of Tweedie
deviance explained.
Best possible score is 1.0, lower values are worse and it
can also be negative.
Since it uses the mean Tweedie deviance, it also includes
the Gaussian, Poisson, Gamma and inverse-Gaussian
distribution families as special cases.
## Examples
iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
iex> Scholar.Metrics.Regression.d2_tweedie_score(y_true, y_pred, 1)
#Nx.Tensor<
f32
0.32202935218811035
>
"""
defn d2_tweedie_score(y_true, y_pred, power) do
if Nx.size(y_pred) < 2 do
Nx.Constants.nan()
else
d2_tweedie_score_n(y_true, y_pred, power)
end
end

defnp d2_tweedie_score_n(y_true, y_pred, power) do
y_true = Nx.squeeze(y_true)
y_pred = Nx.squeeze(y_pred)

numerator = mean_tweedie_deviance_n(y_true, y_pred, power)
y_avg = Nx.mean(y_true)
denominator = mean_tweedie_deviance_n(y_true, y_avg, power)

1 - numerator / denominator
end

@doc ~S"""
Calculates the maximum residual error.
Expand Down
11 changes: 11 additions & 0 deletions test/scholar/metrics/regression_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,15 @@ defmodule Scholar.Metrics.RegressionTest do
end
end
end

describe "d2_tweedie_score/3" do
test "equal R^2 when power is 0" do
y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
d2 = Regression.d2_tweedie_score(y_true, y_pred, 0)
r2 = Regression.r2_score(y_true, y_pred)

assert Nx.equal(d2, r2)
end
end
end

0 comments on commit 3793cd5

Please sign in to comment.