Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mean_tweedie_deviance and particular cases #193

Merged
merged 5 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 201 additions & 2 deletions lib/scholar/metrics/regression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ defmodule Scholar.Metrics.Regression do
any supported `Nx` compiler.
"""

import Nx, only: [is_tensor: 1]
import Nx.Defn, except: [assert_shape: 2, assert_shape_pattern: 2]
import Scholar.Shared
import Scholar.Metrics.Distance
Expand All @@ -26,7 +27,24 @@ defmodule Scholar.Metrics.Regression do
]
]

# Option definitions for `mean_tweedie_deviance/4`. The only option is
# `:check_tensors`, a boolean that defaults to `false`.
mean_tweedie_deviance_schema = [
check_tensors: [
type: :boolean,
default: false,
doc: """
Flag indicating if tensor inputs should be checked to conform with the
necessary properties for the given power value.
"""
]
]

# The Poisson and Gamma deviance functions reuse the Tweedie option
# definitions unchanged.
mean_poisson_deviance_schema = mean_tweedie_deviance_schema
mean_gamma_deviance_schema = mean_tweedie_deviance_schema

# Each schema is wrapped with `NimbleOptions.new!/1` and stored in a module
# attribute; the attributes are later passed to `NimbleOptions.validate!/2`
# and `NimbleOptions.docs/1`.
@r2_schema NimbleOptions.new!(r2_schema)
@mean_tweedie_deviance_schema NimbleOptions.new!(mean_tweedie_deviance_schema)
@mean_poisson_deviance_schema NimbleOptions.new!(mean_poisson_deviance_schema)
@mean_gamma_deviance_schema NimbleOptions.new!(mean_gamma_deviance_schema)

# Standard Metrics

Expand Down Expand Up @@ -71,8 +89,7 @@ defmodule Scholar.Metrics.Regression do
>
"""
defn mean_square_error(y_true, y_pred) do
  # The mean squared error is the mean Tweedie deviance with power 0: that
  # branch of `mean_tweedie_deviance_n/3` squares the residual
  # `y_true - y_pred` before averaging. The former inline computation
  # (`diff * diff |> Nx.mean()`) was dead code whose result was discarded,
  # so it has been removed.
  mean_tweedie_deviance_n(y_true, y_pred, 0)
end

@doc ~S"""
Expand Down Expand Up @@ -133,6 +150,188 @@ defmodule Scholar.Metrics.Regression do
|> Nx.mean()
end

@doc """
Computes the mean Tweedie deviance of predictions with respect to targets.

The `power` argument selects the deviance function. The Gaussian
(`power = 0`), Poisson (`power = 1`), Gamma (`power = 2`) and
inverse-Gaussian (`power = 3`) families are special cases.

#{~S'''
$$d(y,\mu) =
\begin{cases}
(y-\mu)^2, & \text{for }p=0\\\\
2(y \log(y/\mu) + \mu - y), & \text{for }p=1\\\\
2(\log(\mu/y) + y/\mu - 1), & \text{for }p=2\\\\
2\left(\frac{\max(y,0)^{2-p}}{(1-p)(2-p)}-\frac{y\mu^{1-p}}{1-p}+\frac{\mu^{2-p}}{2-p}\right), & \text{otherwise}
\end{cases}$$
'''}

## Options

#{NimbleOptions.docs(@mean_tweedie_deviance_schema)}

## Examples

    iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
    iex> Scholar.Metrics.Regression.mean_tweedie_deviance(y_true, y_pred, 1)
    #Nx.Tensor<
      f32
      0.18411168456077576
    >
"""
deftransform mean_tweedie_deviance(y_true, y_pred, power, opts \\ []) do
  validated = NimbleOptions.validate!(opts, @mean_tweedie_deviance_schema)

  # Input validation is opt-in; when enabled it raises before any deviance
  # is computed if the tensors do not fit the requested power.
  if Keyword.get(validated, :check_tensors) do
    check_tweedie_deviance_power(y_true, y_pred, power)
  end

  mean_tweedie_deviance_n(y_true, y_pred, power)
end

# Numerical core of the mean Tweedie deviance: computes the element-wise
# unit deviance for the given `power` and returns its mean.
#
# The multi-line formulas keep their binary `-`/`+` operators at the END of
# each line. A line that starts with `-` or `+` is parsed by Elixir as a new
# expression with a unary operator, which would silently reduce the
# parenthesised block to its last term only.
defnp mean_tweedie_deviance_n(y_true, y_pred, power) do
  deviance =
    cond do
      # Extreme stable distribution. Not merged with the catch-all clause at
      # the bottom: this branch clamps the targets with `max(y_true, 0)`
      # before raising them to `2 - power`.
      power < 0 ->
        2 *
          (Nx.pow(max(y_true, 0), 2 - power) / ((1 - power) * (2 - power)) -
             y_true * Nx.pow(y_pred, 1 - power) / (1 - power) +
             Nx.pow(y_pred, 2 - power) / (2 - power))

      # Normal distribution
      power == 0 ->
        Nx.pow(y_true - y_pred, 2)

      # Poisson distribution
      # NOTE(review): for `y_true == 0` this evaluates `0 * log(0)`; confirm
      # whether that case should be special-cased to contribute 0.
      power == 1 ->
        2 * (y_true * Nx.log(y_true / y_pred) + y_pred - y_true)

      # Gamma distribution
      power == 2 ->
        2 * (Nx.log(y_pred / y_true) + y_true / y_pred - 1)

      # 1 < power < 2 -> Compound Poisson distribution, non-negative with mass at zero
      # power == 3 -> Inverse-Gaussian distribution
      # power > 2 -> Stable distribution, with support on the positive reals
      true ->
        2 *
          (Nx.pow(y_true, 2 - power) / ((1 - power) * (2 - power)) -
             y_true * Nx.pow(y_pred, 1 - power) / (1 - power) +
             Nx.pow(y_pred, 2 - power) / (2 - power))
    end

  Nx.mean(deviance)
end

# Validates that `y_true` and `y_pred` are admissible for a numeric `power`.
#
# Returns `:ok` when the tensors satisfy the requirements below and raises a
# `RuntimeError` otherwise:
#
#   * `power < 0`       - `y_pred` strictly positive
#   * `power == 0`      - no restriction
#   * `0 < power < 1`   - always raises, the deviance is not defined here
#   * `1 <= power < 2`  - `y_true` non-negative, `y_pred` strictly positive
#   * `power >= 2`      - `y_true` and `y_pred` strictly positive
defp check_tweedie_deviance_power(y_true, y_pred, power) when is_number(power) do
  message = "Mean Tweedie deviance with power=#{power} can only be used on "

  cond do
    power < 0 ->
      if nx_to_bool(Nx.greater(y_pred, 0)) do
        :ok
      else
        raise message <> "strictly positive y_pred."
      end

    power == 0 ->
      :ok

    # Previously these powers fell through to a "branch should never appear"
    # error even though they are perfectly reachable; report the real cause.
    power < 1 ->
      raise "Mean Tweedie deviance is only defined for power <= 0 and power >= 1, " <>
              "got power=#{power}."

    power < 2 ->
      if nx_to_bool(Nx.greater_equal(y_true, 0)) and nx_to_bool(Nx.greater(y_pred, 0)) do
        :ok
      else
        raise message <> "non-negative y_true and strictly positive y_pred."
      end

    # Every remaining number satisfies power >= 2.
    true ->
      if nx_to_bool(Nx.greater(y_true, 0)) and nx_to_bool(Nx.greater(y_pred, 0)) do
        :ok
      else
        raise message <> "strictly positive y_true and strictly positive y_pred."
      end
  end
end

# A tensor power is converted with `Nx.to_number/1` and dispatched again to
# the clause matching the converted value.
defp check_tweedie_deviance_power(y_true, y_pred, power) when is_tensor(power) do
  scalar_power = Nx.to_number(power)
  check_tweedie_deviance_power(y_true, y_pred, scalar_power)
end

# The atoms below are presumably what `Nx.to_number/1` yields for non-finite
# tensors - confirm against the Nx documentation.

# Negative infinity needs the same tensor checks as any negative power, so
# it is validated as -1.
defp check_tweedie_deviance_power(y_true, y_pred, :neg_infinity),
  do: check_tweedie_deviance_power(y_true, y_pred, -1)

# Positive infinity needs the same tensor checks as any power >= 2, so it is
# validated as 2.
defp check_tweedie_deviance_power(y_true, y_pred, :infinity),
  do: check_tweedie_deviance_power(y_true, y_pred, 2)

# A NaN power is rejected outright.
defp check_tweedie_deviance_power(_y_true, _y_pred, :nan),
  do: raise("NaN is not supported.")

# Reduces `tensor` with `Nx.all/1` and converts the resulting 0/1 scalar
# into an Elixir boolean.
defp nx_to_bool(tensor) when is_tensor(tensor) do
  all_true = Nx.all(tensor)
  Nx.to_number(all_true) == 1
end

@doc """
Computes the mean Poisson deviance of predictions with respect to targets.

This is `mean_tweedie_deviance/4` called with `power` set to `1`.

## Options

#{NimbleOptions.docs(@mean_poisson_deviance_schema)}

## Examples

    iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
    iex> Scholar.Metrics.Regression.mean_poisson_deviance(y_true, y_pred)
    #Nx.Tensor<
      f32
      0.18411168456077576
    >
"""
deftransform mean_poisson_deviance(y_true, y_pred, opts \\ []),
  do: mean_tweedie_deviance(y_true, y_pred, 1, opts)

@doc """
Computes the mean Gamma deviance of predictions with respect to targets.

This is `mean_tweedie_deviance/4` called with `power` set to `2`.

## Options

#{NimbleOptions.docs(@mean_gamma_deviance_schema)}

## Examples

    iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)
    iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32)
    iex> Scholar.Metrics.Regression.mean_gamma_deviance(y_true, y_pred)
    #Nx.Tensor<
      f32
      0.115888312458992
    >
"""
deftransform mean_gamma_deviance(y_true, y_pred, opts \\ []),
  do: mean_tweedie_deviance(y_true, y_pred, 2, opts)

@doc """
Calculates the $R^2$ score of predictions with respect to targets.

Expand Down
52 changes: 52 additions & 0 deletions test/scholar/metrics/regression_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,56 @@ defmodule Scholar.Metrics.RegressionTest do

alias Scholar.Metrics.Regression
doctest Regression

# Each test feeds `mean_tweedie_deviance/4` tensors that violate the
# requirement for the chosen power and expects the `check_tensors: true`
# validation to raise.
describe "mean_tweedie_deviance" do
  test "raise when y_pred <= 0 and power < 0 with check_tensors: true" do
    targets = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
    # The last prediction is zero, which is not strictly positive.
    predictions = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

    assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
      Regression.mean_tweedie_deviance(targets, predictions, -1, check_tensors: true)
    end
  end

  test "raise when y_pred <= 0 and 1 <= power < 2 with check_tensors: true" do
    targets = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
    # The last prediction is zero, which is not strictly positive.
    predictions = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

    assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
      Regression.mean_tweedie_deviance(targets, predictions, 1, check_tensors: true)
    end
  end

  test "raise when y_pred <= 0 and power >= 2 with check_tensors: true" do
    targets = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32)
    # The last prediction is zero, which is not strictly positive.
    predictions = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32)

    assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
      Regression.mean_tweedie_deviance(targets, predictions, 2, check_tensors: true)
    end
  end

  test "raise when y_true < 0 and 1 <= power < 2 with check_tensors: true" do
    # The first target is negative; all predictions are valid.
    targets = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32)
    predictions = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32)

    assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
      Regression.mean_tweedie_deviance(targets, predictions, 1, check_tensors: true)
    end
  end

  test "raise when y_true <= 0 and power >= 2 with check_tensors: true" do
    # Zero targets are not strictly positive; all predictions are valid.
    targets = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32)
    predictions = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32)

    assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn ->
      Regression.mean_tweedie_deviance(targets, predictions, 2, check_tensors: true)
    end
  end
end
end
Loading