Skip to content

Commit

Permalink
Add mean_tweedie_deviance and particular cases
Browse files Browse the repository at this point in the history
Particular cases:
- mean_poisson_deviance
- mean_gamma_deviance

And update mean_square_error to use mean_tweedie_deviance as well
  • Loading branch information
0urobor0s committed Oct 17, 2023
1 parent b36df2f commit bbcf5bb
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 2 deletions.
147 changes: 145 additions & 2 deletions lib/scholar/metrics/regression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ defmodule Scholar.Metrics.Regression do
>
"""
defn mean_square_error(y_true, y_pred) do
diff = y_true - y_pred
(diff * diff) |> Nx.mean()
mean_tweedie_deviance_n(y_true, y_pred, 0)
end

@doc ~S"""
Expand Down Expand Up @@ -133,6 +132,150 @@ defmodule Scholar.Metrics.Regression do
|> Nx.mean()
end

@doc """
Calculates the mean Tweedie deviance of predictions
with respect to targets. Includes the Gaussian, Poisson,
Gamma and inverse-Gaussian families as special cases.
#{~S'''
$$d(y,\mu) =
\begin{cases}
(y-\mu)^2, & \text{for }p=0\\\\
2(y \log(y/\mu) + \mu - y), & \text{for }p=1\\\\
2(\log(\mu/y) + y/\mu - 1), & \text{for }p=2\\\\
2\left(\frac{\max(y,0)^{2-p}}{(1-p)(2-p)}-\frac{y\mu^{1-p}}{1-p}+\frac{\mu^{2-p}}{2-p}\right), & \text{for }p<0 \vee p>2
\end{cases}$$
'''}
## Examples
iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
iex> Scholar.Metrics.Regression.mean_tweedie_deviance(y_true, y_pred, 1)
#Nx.Tensor<
f32
0.18411168456077576
>
"""
deftransform mean_tweedie_deviance(y_true, y_pred, power) do
message = "Mean Tweedie deviance with power=#{power} can only be used on "

case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do
2 -> raise message <> "strictly positive y_pred."
4 -> raise message <> "non-negative y_true and strictly positive y_pred."
5 -> raise message <> "strictly positive y_true and strictly positive y_pred."
100 -> raise "Something went wrong, branch should never appear."
1 -> :ok
end

mean_tweedie_deviance_n(y_true, y_pred, power)
end

defnp mean_tweedie_deviance_n(y_true, y_pred, power) do
deviance =
cond do
power < 0 ->
2 *
(
Nx.pow(max(y_true, 0), 2 - power) / ((1 - power) * (2 - power))
-y_true * Nx.pow(y_pred, 1 - power) / (1 - power)
+Nx.pow(y_pred, 2 - power) / (2 - power)
)

# Normal distribution
power == 0 ->
Nx.pow(y_true - y_pred, 2)

# Poisson distribution
power == 1 ->
2 * (y_true * Nx.log(y_true / y_pred) + y_pred - y_true)

# Gamma distribution
power == 2 ->
2 * (Nx.log(y_pred / y_true) + y_true / y_pred - 1)

# 1 < power < 2 -> Compound Poisson distribution, non-negative with mass at zero
# power == 3 -> Inverse-Gaussian distribution
# power > 2 -> Stable distribution, with support on the positive reals
true ->
2 *
(
Nx.pow(y_true, 2 - power) / ((1 - power) * (2 - power))
-y_true * Nx.pow(y_pred, 1 - power) / (1 - power)
+Nx.pow(y_pred, 2 - power) / (2 - power)
)
end

Nx.mean(deviance)
end

defn check_tweedie_deviance_power(y_true, y_pred, power) do
cond do
power < 0 ->
if Nx.all(y_pred > 0) do
Nx.u8(1)
else
Nx.u8(2)
end

power == 0 ->
Nx.u8(1)

power >= 1 and power < 2 ->
if Nx.all(y_true >= 0) and Nx.all(y_pred > 0) do
Nx.u8(1)
else
Nx.u8(4)
end

power >= 2 ->
if Nx.all(y_true > 0) and Nx.all(y_pred > 0) do
Nx.u8(1)
else
Nx.u8(5)
end

true ->
Nx.u8(100)
end
end

@doc ~S"""
Calculates the mean Poisson deviance of predictions
with respect to targets.
## Examples
iex> y_true = Nx.tensor([[0.0, 2.0], [0.5, 0.0]])
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]])
iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred)
#Nx.Tensor<
f32
0.8125
>
"""
defn mean_poisson_deviance(y_true, y_pred) do
mean_tweedie_deviance_n(y_true, y_pred, 1)
end

@doc ~S"""
Calculates the mean Gamma deviance of predictions
with respect to targets.
## Examples
iex> y_true = Nx.tensor([[1.0, 2.0], [0.5, 2.0]])
iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]])
iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred)
#Nx.Tensor<
f32
0.5625
>
"""
defn mean_gamma_deviance(y_true, y_pred) do
mean_tweedie_deviance_n(y_true, y_pred, 2)
end

@doc """
Calculates the $R^2$ score of predictions with respect to targets.
Expand Down
52 changes: 52 additions & 0 deletions test/scholar/metrics/regression_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,56 @@ defmodule Scholar.Metrics.RegressionTest do

alias Scholar.Metrics.Regression
doctest Regression

describe "mean_tweedie_deviance" do
test "raise when y_pred <= 0 and power < 0" do
power = -1
y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
Regression.mean_tweedie_deviance(y_true, y_pred, power)
end
end

test "raise when y_pred <= 0 and 1 <= power < 2" do
power = 1
y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
Regression.mean_tweedie_deviance(y_true, y_pred, power)
end
end

test "raise when y_pred <= 0 and power >= 2" do
power = 2
y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
Regression.mean_tweedie_deviance(y_true, y_pred, power)
end
end

test "raise when y_true < 0 and 1 <= power < 2" do
power = 1
y_true = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32)
y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32)

assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
Regression.mean_tweedie_deviance(y_true, y_pred, power)
end
end

test "raise when y_true <= 0 and power >= 2" do
power = 2
y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32)
y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32)

assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
Regression.mean_tweedie_deviance(y_true, y_pred, power)
end
end
end
end

0 comments on commit bbcf5bb

Please sign in to comment.