From bbcf5bb7c5f9b47658951dd64392104fdf8f68f8 Mon Sep 17 00:00:00 2001 From: 0urobor0s <0urobor0s@users.noreply.github.com> Date: Tue, 17 Oct 2023 14:40:59 +0100 Subject: [PATCH 1/5] Add mean_tweedie_deviance and particular cases Particular cases: - mean_poisson_deviance - mean_gamma_deviance And update mean_square_error to use mean_tweedie_deviance as well --- lib/scholar/metrics/regression.ex | 147 ++++++++++++++++++++++- test/scholar/metrics/regression_test.exs | 52 ++++++++ 2 files changed, 197 insertions(+), 2 deletions(-) diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 725f7b99..3d61f6c6 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -71,8 +71,7 @@ defmodule Scholar.Metrics.Regression do > """ defn mean_square_error(y_true, y_pred) do - diff = y_true - y_pred - (diff * diff) |> Nx.mean() + mean_tweedie_deviance_n(y_true, y_pred, 0) end @doc ~S""" @@ -133,6 +132,150 @@ defmodule Scholar.Metrics.Regression do |> Nx.mean() end + @doc """ + Calculates the mean Tweedie deviance of predictions + with respect to targets. Includes the Gaussian, Poisson, + Gamma and inverse-Gaussian families as special cases. + + #{~S''' + $$d(y,\mu) = + \begin{cases} + (y-\mu)^2, & \text{for }p=0\\\\ + 2(y \log(y/\mu) + \mu - y), & \text{for }p=1\\\\ + 2(\log(\mu/y) + y/\mu - 1), & \text{for }p=2\\\\ + 2\left(\frac{\max(y,0)^{2-p}}{(1-p)(2-p)}-\frac{y\mu^{1-p}}{1-p}+\frac{\mu^{2-p}}{2-p}\right), & \text{for }p<0 \vee p>2 + \end{cases}$$ + '''} + + ## Examples + + iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) + iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) + iex> Scholar.Metrics.Regression.mean_tweedie_deviance(y_true, y_pred, 1) + #Nx.Tensor< + f32 + 0.18411168456077576 + > + """ + deftransform mean_tweedie_deviance(y_true, y_pred, power) do + message = "Mean Tweedie deviance with power=#{power} can only be used on " + + case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do + 2 -> raise message <> "strictly positive y_pred." + 4 -> raise message <> "non-negative y_true and strictly positive y_pred." + 5 -> raise message <> "strictly positive y_true and strictly positive y_pred." + 100 -> raise "Something went wrong, branch should never appear." + 1 -> :ok + end + + mean_tweedie_deviance_n(y_true, y_pred, power) + end + + defnp mean_tweedie_deviance_n(y_true, y_pred, power) do + deviance = + cond do + power < 0 -> + 2 * + ( + Nx.pow(max(y_true, 0), 2 - power) / ((1 - power) * (2 - power)) + -y_true * Nx.pow(y_pred, 1 - power) / (1 - power) + +Nx.pow(y_pred, 2 - power) / (2 - power) + ) + + # Normal distribution + power == 0 -> + Nx.pow(y_true - y_pred, 2) + + # Poisson distribution + power == 1 -> + 2 * (y_true * Nx.log(y_true / y_pred) + y_pred - y_true) + + # Gamma distribution + power == 2 -> + 2 * (Nx.log(y_pred / y_true) + y_true / y_pred - 1) + + # 1 < power < 2 -> Compound Poisson distribution, non-negative with mass at zero + # power == 3 -> Inverse-Gaussian distribution + # power > 2 -> Stable distribution, with support on the positive reals + true -> + 2 * + ( + Nx.pow(y_true, 2 - power) / ((1 - power) * (2 - power)) + -y_true * Nx.pow(y_pred, 1 - power) / (1 - power) + +Nx.pow(y_pred, 2 - power) / (2 - power) + ) + end + + Nx.mean(deviance) + end + + defn check_tweedie_deviance_power(y_true, y_pred, power) do + cond do + power < 0 -> + if Nx.all(y_pred > 0) do + Nx.u8(1) + else + Nx.u8(2) + end + + power == 0 -> + Nx.u8(1) + + power >= 1 and power < 2 -> + if Nx.all(y_true >= 0) and Nx.all(y_pred > 0) do + Nx.u8(1) + else + Nx.u8(4) + end + + power >= 2 -> + if Nx.all(y_true > 0) and Nx.all(y_pred > 0) do + Nx.u8(1) + else + Nx.u8(5) + end + + true -> + Nx.u8(100) + end + end + + @doc ~S""" + Calculates the mean Poisson deviance of predictions + with respect to targets. + + ## Examples + + iex> y_true = Nx.tensor([[0.0, 2.0], [0.5, 0.0]]) + iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]]) + iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred) + #Nx.Tensor< + f32 + 0.8125 + > + """ + defn mean_poisson_deviance(y_true, y_pred) do + mean_tweedie_deviance_n(y_true, y_pred, 1) + end + + @doc ~S""" + Calculates the mean Gamma deviance of predictions + with respect to targets. + + ## Examples + + iex> y_true = Nx.tensor([[1.0, 2.0], [0.5, 2.0]]) + iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]]) + iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred) + #Nx.Tensor< + f32 + 0.5625 + > + """ + defn mean_gamma_deviance(y_true, y_pred) do + mean_tweedie_deviance_n(y_true, y_pred, 2) + end + @doc """ Calculates the $R^2$ score of predictions with respect to targets. diff --git a/test/scholar/metrics/regression_test.exs b/test/scholar/metrics/regression_test.exs index 68258382..d5711ce3 100644 --- a/test/scholar/metrics/regression_test.exs +++ b/test/scholar/metrics/regression_test.exs @@ -3,4 +3,56 @@ defmodule Scholar.Metrics.RegressionTest do alias Scholar.Metrics.Regression doctest Regression + + describe "mean_tweedie_deviance" do + test "raise when y_pred <= 0 and power < 0" do + power = -1 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_pred <= 0 and 1 <= power < 2" do + power = 1 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_pred <= 0 and power >= 2" do + power = 2 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_true < 0 and 1 <= power < 2" do + power = 1 + y_true = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + + test "raise when y_true <= 0 and power >= 2" do + power = 2 + y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) + y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) + + assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance(y_true, y_pred, power) + end + end + end end From a25d537eed6fe98757ca62924ac8e700fa9fcd99 Mon Sep 17 00:00:00 2001 From: 0urobor0s <0urobor0s@users.noreply.github.com> Date: Wed, 18 Oct 2023 17:18:43 +0100 Subject: [PATCH 2/5] Use def for tensor check and make the check optional --- lib/scholar/metrics/regression.ex | 124 ++++++++++++++++------- test/scholar/metrics/regression_test.exs | 20 ++-- 2 files changed, 100 insertions(+), 44 deletions(-) diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 3d61f6c6..37ebec5b 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -11,6 +11,7 @@ defmodule Scholar.Metrics.Regression do any supported `Nx` compiler. """ + import Nx, only: [is_tensor: 1] import Nx.Defn, except: [assert_shape: 2, assert_shape_pattern: 2] import Scholar.Shared import Scholar.Metrics.Distance @@ -26,7 +27,24 @@ defmodule Scholar.Metrics.Regression do ] ] + mean_tweedie_deviance_schema = [ + check_tensors: [ + type: :boolean, + default: false, + doc: """ + Flag indicating if tensor inputs should be checked to conform with the + necessary properties for the given power value. + """ + ] + ] + + mean_poisson_deviance_schema = mean_tweedie_deviance_schema + mean_gamma_deviance_schema = mean_tweedie_deviance_schema + @r2_schema NimbleOptions.new!(r2_schema) + @mean_tweedie_deviance_schema NimbleOptions.new!(mean_tweedie_deviance_schema) + @mean_poisson_deviance_schema NimbleOptions.new!(mean_poisson_deviance_schema) + @mean_gamma_deviance_schema NimbleOptions.new!(mean_gamma_deviance_schema) # Standard Metrics @@ -147,6 +165,10 @@ defmodule Scholar.Metrics.Regression do \end{cases}$$ '''} + ## Options + + #{NimbleOptions.docs(@mean_tweedie_deviance_schema)} + ## Examples iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) @@ -157,15 +179,11 @@ defmodule Scholar.Metrics.Regression do 0.18411168456077576 > """ - deftransform mean_tweedie_deviance(y_true, y_pred, power) do - message = "Mean Tweedie deviance with power=#{power} can only be used on " + deftransform mean_tweedie_deviance(y_true, y_pred, power, opts \\ []) do + opts = NimbleOptions.validate!(opts, @mean_tweedie_deviance_schema) - case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do - 2 -> raise message <> "strictly positive y_pred." - 4 -> raise message <> "non-negative y_true and strictly positive y_pred." - 5 -> raise message <> "strictly positive y_true and strictly positive y_pred." - 100 -> raise "Something went wrong, branch should never appear." - 1 -> :ok + if opts[:check_tensors] do + check_tweedie_deviance_power(y_true, y_pred, power) end mean_tweedie_deviance_n(y_true, y_pred, power) @@ -209,71 +227,109 @@ defmodule Scholar.Metrics.Regression do Nx.mean(deviance) end - defn check_tweedie_deviance_power(y_true, y_pred, power) do + defp check_tweedie_deviance_power(y_true, y_pred, power) when is_number(power) do + message = "Mean Tweedie deviance with power=#{power} can only be used on " + cond do power < 0 -> - if Nx.all(y_pred > 0) do - Nx.u8(1) + if nx_to_bool(Nx.greater(y_pred, 0)) do + :ok else - Nx.u8(2) + raise message <> "strictly positive y_pred." end power == 0 -> - Nx.u8(1) + :ok power >= 1 and power < 2 -> - if Nx.all(y_true >= 0) and Nx.all(y_pred > 0) do - Nx.u8(1) + if nx_to_bool(Nx.greater_equal(y_true, 0)) and nx_to_bool(Nx.greater(y_pred, 0)) do + :ok else - Nx.u8(4) + raise message <> "non-negative y_true and strictly positive y_pred." end power >= 2 -> - if Nx.all(y_true > 0) and Nx.all(y_pred > 0) do - Nx.u8(1) + if nx_to_bool(Nx.greater(y_true, 0)) and nx_to_bool(Nx.greater(y_pred, 0)) do + :ok else - Nx.u8(5) + raise message <> "strictly positive y_true and strictly positive y_pred." end true -> - Nx.u8(100) + raise "Something went wrong, branch should never appear." end end - @doc ~S""" + defp check_tweedie_deviance_power(y_true, y_pred, power) when is_tensor(power) do + check_tweedie_deviance_power(y_true, y_pred, Nx.to_number(power)) + end + + defp check_tweedie_deviance_power(y_true, y_pred, :neg_infinity) do + # Same math function check + check_tweedie_deviance_power(y_true, y_pred, -1) + end + + defp check_tweedie_deviance_power(y_true, y_pred, :infinity) do + # Same math function check + check_tweedie_deviance_power(y_true, y_pred, 2) + end + + defp check_tweedie_deviance_power(_y_true, _y_pred, :nan) do + raise "NaN is not supported." + end + + defp nx_to_bool(tensor) when is_tensor(tensor) do + tensor + |> Nx.all() + |> Nx.to_number() + |> case do + 0 -> false + 1 -> true + end + end + + @doc """ Calculates the mean Poisson deviance of predictions with respect to targets. + ## Options + + #{NimbleOptions.docs(@mean_poisson_deviance_schema)} + ## Examples - iex> y_true = Nx.tensor([[0.0, 2.0], [0.5, 0.0]]) - iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]]) - iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred) + iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) + iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) + iex> Scholar.Metrics.Regression.mean_poisson_deviance(y_true, y_pred) #Nx.Tensor< f32 - 0.8125 + 0.18411168456077576 > """ - defn mean_poisson_deviance(y_true, y_pred) do - mean_tweedie_deviance_n(y_true, y_pred, 1) + deftransform mean_poisson_deviance(y_true, y_pred, opts \\ []) do + mean_tweedie_deviance(y_true, y_pred, 1, opts) end - @doc ~S""" + @doc """ Calculates the mean Gamma deviance of predictions with respect to targets. + ## Options + + #{NimbleOptions.docs(@mean_gamma_deviance_schema)} + ## Examples - iex> y_true = Nx.tensor([[1.0, 2.0], [0.5, 2.0]]) - iex> y_pred = Nx.tensor([[1.0, 1.0], [1.0, 1.0]]) - iex> Scholar.Metrics.Regression.mean_square_error(y_true, y_pred) + iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) + iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) + iex> Scholar.Metrics.Regression.mean_gamma_deviance(y_true, y_pred) #Nx.Tensor< f32 - 0.5625 + 0.115888312458992 > """ - defn mean_gamma_deviance(y_true, y_pred) do - mean_tweedie_deviance_n(y_true, y_pred, 2) + deftransform mean_gamma_deviance(y_true, y_pred, opts \\ []) do + mean_tweedie_deviance(y_true, y_pred, 2, opts) end @doc """ diff --git a/test/scholar/metrics/regression_test.exs b/test/scholar/metrics/regression_test.exs index d5711ce3..ee4634f0 100644 --- a/test/scholar/metrics/regression_test.exs +++ b/test/scholar/metrics/regression_test.exs @@ -5,53 +5,53 @@ defmodule Scholar.Metrics.RegressionTest do doctest Regression describe "mean_tweedie_deviance" do - test "raise when y_pred <= 0 and power < 0" do + test "raise when y_pred <= 0 and power < 0 with check_tensors: true" do power = -1 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power) + Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) end end - test "raise when y_pred <= 0 and 1 <= power < 2" do + test "raise when y_pred <= 0 and 1 <= power < 2 with check_tensors: true" do power = 1 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power) + Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) end end - test "raise when y_pred <= 0 and power >= 2" do + test "raise when y_pred <= 0 and power >= 2 with check_tensors: true" do power = 2 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power) + Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) end end - test "raise when y_true < 0 and 1 <= power < 2" do + test "raise when y_true < 0 and 1 <= power < 2 with check_tensors: true" do power = 1 y_true = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power) + Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) end end - test "raise when y_true <= 0 and power >= 2" do + test "raise when y_true <= 0 and power >= 2 with check_tensors: true" do power = 2 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power) + Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) end end end From 41792b1f28a538dbeea956b7f146a47a18be990a Mon Sep 17 00:00:00 2001 From: 0urobor0s <0urobor0s@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:41:49 +0100 Subject: [PATCH 3/5] Move checks into mean_tweedie_deviance! --- lib/scholar/metrics/regression.ex | 122 ++++++++--------------- test/scholar/metrics/regression_test.exs | 32 +++--- 2 files changed, 58 insertions(+), 96 deletions(-) diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 37ebec5b..7e5f9e72 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -11,7 +11,6 @@ defmodule Scholar.Metrics.Regression do any supported `Nx` compiler. """ - import Nx, only: [is_tensor: 1] import Nx.Defn, except: [assert_shape: 2, assert_shape_pattern: 2] import Scholar.Shared import Scholar.Metrics.Distance @@ -27,24 +26,7 @@ defmodule Scholar.Metrics.Regression do ] ] - mean_tweedie_deviance_schema = [ - check_tensors: [ - type: :boolean, - default: false, - doc: """ - Flag indicating if tensor inputs should be checked to conform with the - necessary properties for the given power value. - """ - ] - ] - - mean_poisson_deviance_schema = mean_tweedie_deviance_schema - mean_gamma_deviance_schema = mean_tweedie_deviance_schema - @r2_schema NimbleOptions.new!(r2_schema) - @mean_tweedie_deviance_schema NimbleOptions.new!(mean_tweedie_deviance_schema) - @mean_poisson_deviance_schema NimbleOptions.new!(mean_poisson_deviance_schema) - @mean_gamma_deviance_schema NimbleOptions.new!(mean_gamma_deviance_schema) # Standard Metrics @@ -165,10 +147,6 @@ defmodule Scholar.Metrics.Regression do \end{cases}$$ '''} - ## Options - - #{NimbleOptions.docs(@mean_tweedie_deviance_schema)} - ## Examples iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) @@ -179,11 +157,33 @@ defmodule Scholar.Metrics.Regression do 0.18411168456077576 > """ - deftransform mean_tweedie_deviance(y_true, y_pred, power, opts \\ []) do - opts = NimbleOptions.validate!(opts, @mean_tweedie_deviance_schema) + deftransform mean_tweedie_deviance(y_true, y_pred, power) do + mean_tweedie_deviance_n(y_true, y_pred, power) + end + + @doc """ + Similar to `mean_tweedie_deviance/3` but raises `RuntimeError` if the + inputs cannot be used with the given power argument. + + ## Examples - if opts[:check_tensors] do - check_tweedie_deviance_power(y_true, y_pred, power) + iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) + iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) + iex> Scholar.Metrics.Regression.mean_tweedie_deviance!(y_true, y_pred, 1) + #Nx.Tensor< + f32 + 0.18411168456077576 + > + """ + def mean_tweedie_deviance!(y_true, y_pred, power) do + message = "mean Tweedie deviance with power=#{power} can only be used on " + + case check_tweedie_deviance_power(y_true, y_pred, power) |> Nx.to_number() do + 1 -> :ok + 2 -> raise message <> "strictly positive y_pred" + 4 -> raise message <> "non-negative y_true and strictly positive y_pred" + 5 -> raise message <> "strictly positive y_true and strictly positive y_pred" + 100 -> raise "something went wrong, branch should never appear" end mean_tweedie_deviance_n(y_true, y_pred, power) @@ -227,64 +227,34 @@ defmodule Scholar.Metrics.Regression do Nx.mean(deviance) end - defp check_tweedie_deviance_power(y_true, y_pred, power) when is_number(power) do - message = "Mean Tweedie deviance with power=#{power} can only be used on " - + defnp check_tweedie_deviance_power(y_true, y_pred, power) do cond do power < 0 -> - if nx_to_bool(Nx.greater(y_pred, 0)) do - :ok + if Nx.all(y_pred > 0) do + Nx.u8(1) else - raise message <> "strictly positive y_pred." + Nx.u8(2) end power == 0 -> - :ok + Nx.u8(1) power >= 1 and power < 2 -> - if nx_to_bool(Nx.greater_equal(y_true, 0)) and nx_to_bool(Nx.greater(y_pred, 0)) do - :ok + if Nx.all(y_true >= 0) and Nx.all(y_pred > 0) do + Nx.u8(1) else - raise message <> "non-negative y_true and strictly positive y_pred." + Nx.u8(4) end power >= 2 -> - if nx_to_bool(Nx.greater(y_true, 0)) and nx_to_bool(Nx.greater(y_pred, 0)) do - :ok + if Nx.all(y_true > 0) and Nx.all(y_pred > 0) do + Nx.u8(1) else - raise message <> "strictly positive y_true and strictly positive y_pred." + Nx.u8(5) end true -> - raise "Something went wrong, branch should never appear." - end - end - - defp check_tweedie_deviance_power(y_true, y_pred, power) when is_tensor(power) do - check_tweedie_deviance_power(y_true, y_pred, Nx.to_number(power)) - end - - defp check_tweedie_deviance_power(y_true, y_pred, :neg_infinity) do - # Same math function check - check_tweedie_deviance_power(y_true, y_pred, -1) - end - - defp check_tweedie_deviance_power(y_true, y_pred, :infinity) do - # Same math function check - check_tweedie_deviance_power(y_true, y_pred, 2) - end - - defp check_tweedie_deviance_power(_y_true, _y_pred, :nan) do - raise "NaN is not supported." - end - - defp nx_to_bool(tensor) when is_tensor(tensor) do - tensor - |> Nx.all() - |> Nx.to_number() - |> case do - 0 -> false - 1 -> true + Nx.u8(100) end end @@ -292,10 +262,6 @@ defmodule Scholar.Metrics.Regression do Calculates the mean Poisson deviance of predictions with respect to targets. - ## Options - - #{NimbleOptions.docs(@mean_poisson_deviance_schema)} - ## Examples iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) @@ -306,18 +272,14 @@ defmodule Scholar.Metrics.Regression do 0.18411168456077576 > """ - deftransform mean_poisson_deviance(y_true, y_pred, opts \\ []) do - mean_tweedie_deviance(y_true, y_pred, 1, opts) + defn mean_poisson_deviance(y_true, y_pred) do + mean_tweedie_deviance_n(y_true, y_pred, 1) end @doc """ Calculates the mean Gamma deviance of predictions with respect to targets. - ## Options - - #{NimbleOptions.docs(@mean_gamma_deviance_schema)} - ## Examples iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) @@ -328,8 +290,8 @@ defmodule Scholar.Metrics.Regression do 0.115888312458992 > """ - deftransform mean_gamma_deviance(y_true, y_pred, opts \\ []) do - mean_tweedie_deviance(y_true, y_pred, 2, opts) + defn mean_gamma_deviance(y_true, y_pred) do + mean_tweedie_deviance_n(y_true, y_pred, 2) end @doc """ diff --git a/test/scholar/metrics/regression_test.exs b/test/scholar/metrics/regression_test.exs index ee4634f0..116cf642 100644 --- a/test/scholar/metrics/regression_test.exs +++ b/test/scholar/metrics/regression_test.exs @@ -4,54 +4,54 @@ defmodule Scholar.Metrics.RegressionTest do alias Scholar.Metrics.Regression doctest Regression - describe "mean_tweedie_deviance" do - test "raise when y_pred <= 0 and power < 0 with check_tensors: true" do + describe "mean_tweedie_deviance!/3" do + test "raise when y_pred <= 0 and power < 0" do power = -1 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) - assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) + assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance!(y_true, y_pred, power) end end - test "raise when y_pred <= 0 and 1 <= power < 2 with check_tensors: true" do + test "raise when y_pred <= 0 and 1 <= power < 2" do power = 1 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) - assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) + assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance!(y_true, y_pred, power) end end - test "raise when y_pred <= 0 and power >= 2 with check_tensors: true" do + test "raise when y_pred <= 0 and power >= 2" do power = 2 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :u32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0], type: :u32) - assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) + assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance!(y_true, y_pred, power) end end - test "raise when y_true < 0 and 1 <= power < 2 with check_tensors: true" do + test "raise when y_true < 0 and 1 <= power < 2" do power = 1 y_true = Nx.tensor([-1, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) - assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) + assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance!(y_true, y_pred, power) end end - test "raise when y_true <= 0 and power >= 2 with check_tensors: true" do + test "raise when y_true <= 0 and power >= 2" do power = 2 y_true = Nx.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1], type: :s32) y_pred = Nx.tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], type: :s32) - assert_raise RuntimeError, ~r/Mean Tweedie deviance/, fn -> - Regression.mean_tweedie_deviance(y_true, y_pred, power, check_tensors: true) + assert_raise RuntimeError, ~r/mean Tweedie deviance/, fn -> + Regression.mean_tweedie_deviance!(y_true, y_pred, power) end end end From d216d01f682b217401737a12f8ca8265236573d3 Mon Sep 17 00:00:00 2001 From: 0urobor0s <0urobor0s@users.noreply.github.com> Date: Wed, 18 Oct 2023 18:43:11 +0100 Subject: [PATCH 4/5] deftransform -> defn --- lib/scholar/metrics/regression.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 7e5f9e72..7e15c654 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -157,7 +157,7 @@ defmodule Scholar.Metrics.Regression do 0.18411168456077576 > """ - deftransform mean_tweedie_deviance(y_true, y_pred, power) do + defn mean_tweedie_deviance(y_true, y_pred, power) do mean_tweedie_deviance_n(y_true, y_pred, power) end From 96e31482cc3274233759348788907bd41eecdc12 Mon Sep 17 00:00:00 2001 From: 0urobor0s <0urobor0s@users.noreply.github.com> Date: Thu, 19 Oct 2023 20:07:52 +0100 Subject: [PATCH 5/5] Improve docs --- lib/scholar/metrics/regression.ex | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 7e15c654..80050a4d 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -165,6 +165,8 @@ defmodule Scholar.Metrics.Regression do Similar to `mean_tweedie_deviance/3` but raises `RuntimeError` if the inputs cannot be used with the given power argument. + Note: This function cannot be used in `defn`. + ## Examples iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32)