From 78365a26f7f2ff037b12d2e4142cde8f0f92d894 Mon Sep 17 00:00:00 2001 From: 0urobor0s <0urobor0s@users.noreply.github.com> Date: Thu, 26 Oct 2023 18:58:25 +0100 Subject: [PATCH] Add d2_tweedie_score --- lib/scholar/metrics/regression.ex | 40 ++++++++++++++++++++++++ test/scholar/metrics/regression_test.exs | 11 +++++++ 2 files changed, 51 insertions(+) diff --git a/lib/scholar/metrics/regression.ex b/lib/scholar/metrics/regression.ex index 80050a4d..aa2385ad 100644 --- a/lib/scholar/metrics/regression.ex +++ b/lib/scholar/metrics/regression.ex @@ -455,6 +455,46 @@ defmodule Scholar.Metrics.Regression do end end + @doc """ + $D^2$ regression score function, fraction of Tweedie + deviance explained. + + Best possible score is 1.0, lower values are worse and it + can also be negative. + + Since it uses the mean Tweedie deviance, it also includes + the Gaussian, Poisson, Gamma and inverse-Gaussian + distribution families as special cases. + + ## Examples + + iex> y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) + iex> y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) + iex> Scholar.Metrics.Regression.d2_tweedie_score(y_true, y_pred, 1) + #Nx.Tensor< + f32 + 0.32202935218811035 + > + """ + defn d2_tweedie_score(y_true, y_pred, power) do + if Nx.size(y_pred) < 2 do + Nx.Constants.nan() + else + d2_tweedie_score_n(y_true, y_pred, power) + end + end + + defnp d2_tweedie_score_n(y_true, y_pred, power) do + y_true = Nx.squeeze(y_true) + y_pred = Nx.squeeze(y_pred) + + numerator = mean_tweedie_deviance_n(y_true, y_pred, power) + y_avg = Nx.mean(y_true) + denominator = mean_tweedie_deviance_n(y_true, y_avg, power) + + 1 - numerator / denominator + end + @doc ~S""" Calculates the maximum residual error. diff --git a/test/scholar/metrics/regression_test.exs b/test/scholar/metrics/regression_test.exs index 116cf642..2522da25 100644 --- a/test/scholar/metrics/regression_test.exs +++ b/test/scholar/metrics/regression_test.exs @@ -55,4 +55,15 @@ defmodule Scholar.Metrics.RegressionTest do end end end + + describe "d2_tweedie_score/3" do + test "equal R^2 when power is 0" do + y_true = Nx.tensor([1, 1, 1, 1, 1, 2, 2, 1, 3, 1], type: :u32) + y_pred = Nx.tensor([2, 2, 1, 1, 2, 2, 2, 1, 3, 1], type: :u32) + d2 = Regression.d2_tweedie_score(y_true, y_pred, 0) + r2 = Regression.r2_score(y_true, y_pred) + + assert Nx.equal(d2, r2) + end + end end