diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 725f7b99..3d61f6c6 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -71,8 +71,7 @@ defmodule Scholar.Metrics.Regression do > """ defn mean_square_error(y_true, y_pred) do - diff = y_true - y_pred - (diff * diff) |> Nx.mean() + mean_tweedie_deviance_n(y_true, y_pred, 0) end @doc ~S""" @@ -133,6 +132,150 @@ defmodule Scholar.Metrics.Regression do |> Nx.mean() end + @doc """ + Calculates the mean Tweedie deviance of predictions + with respect to targets. Includes the Gaussian, Poisson, + Gamma and inverse-Gaussian families as special cases. + + #{~S''' + $$d(y,\mu) = + \begin{cases} + (y-\mu)^2, & \text{for }p=0\\\\ + 2(y \log(y/\mu) + \mu - y), & \text{for }p=1\\\\ + 2(\log(\mu/y) + y/\mu - 1), & \text{for }p=2\\\\ + 2\left(\frac{\max(y,0)^{2-p}}{(1-p)(2-p)}-\frac{y\mu^{1-p}}{1-p}+\frac{\mu^{2-p}}{2-p}\right), & \text{for }p<0 \vee p>2 + \end{cases}$$ + '''} + + ## Examples + + iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) + iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) + iex> Scholar.Metrics.Regression.mean_tweedie_deviance(y_true, y_pred, 1) + #Nx.Tensor< + f32 + 0.18411168456077576 + > + """ + deftransform mean_tweedie_deviance(y_true, y_pred, power) do + message = "Mean Tweedie deviance with power=#{power} can only be used on " + + case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do + 2 -> raise message <> "strictly positive y_pred." + 4 -> raise message <> "non-negative y_true and strictly positive y_pred." + 5 -> raise message <> "strictly positive y_true and strictly positive y_pred." + 100 -> raise "Something went wrong, branch should never appear." + 1 -> :ok + end + + mean_tweedie_deviance_n(y_true, y_pred, power) + end + + defnp mean_tweedie_deviance_n(y_true, y_pred, power) do + deviance = + cond do + power < 0 -> + 2 * + ( + Nx.pow(max(y_true, 0), 2 - power) / ((1 - power) * (2 - power)) + -y_true * Nx.pow(y_pred, 1 - power) / (1 - power) + +Nx.pow(y_pred, 2 - power) / (2 - power) + ) + + # Normal distribution + power == 0 -> + Nx.pow(y_true - y_pred, 2) + + # Poisson distribution + power == 1 -> + 2 * (y_true * Nx.log(y_true / y_pred) + y_pred - y_true) + + # Gamma distribution + power == 2 -> + 2 * (Nx.log(y_pred / y_true) + y_true / y_pred - 1) + + # 1 < power < 2 -> Compound Poisson distribution, non-negative with mass at zero + # power == 3 -> Inverse-Gaussian distribution + # power > 2 -> Stable distribution, with support on the positive reals + true -> + 2 * + ( + Nx.pow(y_true, 2 - power) / ((1 - power) * (2 - power)) + -y_true * Nx.pow(y_pred, 1 - power) / (1 - power) + +Nx.pow(y_pred, 2 - power) / (2 - power) + ) + end + + Nx.mean(deviance) + end + + defn check_tweedie_deviance_power(y_true, y_pred, power) do + cond do + power < 0 -> + if Nx.all(y_pred > 0) do + Nx.u8(1) + else + Nx.u8(2) + end + + power == 0 -> + Nx.u8(1) + + power >= 1 and power < 2 -> + if Nx.all(y_true >= 0) and Nx.all(y_pred > 0) do + Nx.u8(1) + else + Nx.u8(4) + end + + power >= 2 -> + if Nx.all(y_true > 0) and Nx.all(y_pred > 0) do + Nx.u8(1) + else + Nx.u8(5) + end + + true -> + Nx.u8(100) + end + end + + @doc ~S""" + Calculates the mean Poisson deviance of predictions + with respect to targets. + + ## Examples + + iex> y_true = Nx.tensor([[0.0, 2.0], [0.5, 0.0]]) + iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]]) + iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred) + #Nx.Tensor< + f32 + 0.8125 + > + """ + defn mean_poisson_deviance(y_true, y_pred) do + mean_tweedie_deviance_n(y_true, y_pred, 1) + end + + @doc ~S""" + Calculates the mean Gamma deviance of predictions + with respect to targets. + + ## Examples + + iex> y_true = Nx.tensor([[1.0, 2.0], [0.5, 2.0]]) + iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]]) + iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred) + #Nx.Tensor< + f32 + 0.5625 + > + """ + defn mean_gamma_deviance(y_true, y_pred) do + mean_tweedie_deviance_n(y_true, y_pred, 2) + end + @doc """ Calculates the $R^2$ score of predictions with respect to targets. diff --git a/test/scholar/metrics/regression_test.exs b/test/scholar/metrics/regression_test.exs index 68258382..d5711ce3 100644 --- a/test/scholar/metrics/regression_test.exs +++ b/test/scholar/metrics/regression_test.exs @@ -3,4 +3,56 @@ defmodule Scholar.Metrics.RegressionTest do alias Scholar.Metrics.Regression doctest Regression + + describe "mean_tweedie_deviance" do + test "raise when y_pred <= 0 and power < 0" do + power = -1 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_pred <= 0 and 1 <= power < 2" do + power = 1 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_pred <= 0 and power >= 2" do + power = 2 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_true < 0 and 1 <= power < 2" do + power = 1 + y_true = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_true <= 0 and power >= 2" do + power = 2 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + end end