From 09d500af2bb731776d96978f6285906f6a4d8f06 Mon Sep 17 00:00:00 2001
From: JoaquinIglesiasTurina
Date: Thu, 16 May 2024 22:51:34 +0200
Subject: [PATCH] Unify weight handling and refactor linear models' helper functions (#267)

---
 .../linear/bayesian_ridge_regression.ex       | 51 ++-------------
 lib/scholar/linear/linear_helpers.ex          | 62 +++++++++++++++++++
 lib/scholar/linear/linear_regression.ex       | 45 ++------------
 lib/scholar/linear/polynomial_regression.ex   |  2 +-
 lib/scholar/linear/ridge_regression.ex        | 51 ++-------------
 .../linear/bayesian_ridge_regression_test.exs |  2 +-
 6 files changed, 81 insertions(+), 132 deletions(-)
 create mode 100644 lib/scholar/linear/linear_helpers.ex

diff --git a/lib/scholar/linear/bayesian_ridge_regression.ex b/lib/scholar/linear/bayesian_ridge_regression.ex
index 22ef0e4d..8af3b8fe 100644
--- a/lib/scholar/linear/bayesian_ridge_regression.ex
+++ b/lib/scholar/linear/bayesian_ridge_regression.ex
@@ -62,6 +62,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
   require Nx
   import Nx.Defn
   import Scholar.Shared
+  alias Scholar.Linear.LinearHelpers
 
   @derive {Nx.Container,
            containers: [
@@ -95,13 +96,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
      """
     ],
     sample_weights: [
-      type:
-        {:or,
-         [
-           {:custom, Scholar.Options, :non_negative_number, []},
-           {:list, {:custom, Scholar.Options, :non_negative_number, []}},
-           {:custom, Scholar.Options, :weights, []}
-         ]},
+      type: {:custom, Scholar.Options, :weights, []},
       doc: """
       The weights for each observation. If not provided,
       all observations are assigned equal weight.
@@ -237,13 +232,9 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
      ] ++
        opts
 
-    {sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
    x_type = to_float_type(x)
 
-    sample_weights =
-      if Nx.is_tensor(sample_weights),
-        do: Nx.as_type(sample_weights, x_type),
-        else: Nx.tensor(sample_weights, type: x_type)
+    sample_weights = LinearHelpers.build_sample_weights(x, opts)
 
    # handle vector types
    # handle default alpha value, add eps to avoid division by 0
@@ -288,7 +279,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
 
    {x_offset, y_offset} =
      if opts[:fit_intercept?] do
-        preprocess_data(x, y, sample_weights, opts)
+        LinearHelpers.preprocess_data(x, y, sample_weights, opts)
      else
        x_offset_shape = Nx.axis_size(x, 1)
        y_reshaped = if Nx.rank(y) > 1, do: y, else: Nx.reshape(y, {:auto, 1})
@@ -302,7 +293,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
 
    {x, y} =
      if opts[:sample_weights_flag] do
-        rescale(x, y, sample_weights)
+        LinearHelpers.rescale(x, y, sample_weights)
      else
        {x, y}
      end
@@ -360,7 +351,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
        {x, y, xt_y, u, s, vh, eigenvals, alpha_1, alpha_2, lambda_1, lambda_2, iterations}}
    end
 
-    intercept = set_intercept(coef, x_offset, y_offset, opts[:fit_intercept?])
+    intercept = LinearHelpers.set_intercept(coef, x_offset, y_offset, opts[:fit_intercept?])
    scaled_sigma = Nx.dot(vh, [0], vh / Nx.new_axis(eigenvals + lambda / alpha, -1), [0])
    sigma = scaled_sigma / alpha
    {coef, intercept, alpha, lambda, iter, has_converged, scores, sigma}
@@ -449,34 +440,4 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
  end
 
  defnp predict_n(coeff, intercept, x), do: Nx.dot(x, [-1], coeff, [-1]) + intercept
-
-  # Implements sample weighting by rescaling inputs and
-  # targets by sqrt(sample_weight).
-  defnp rescale(x, y, sample_weights) do
-    factor = Nx.sqrt(sample_weights)
-
-    x_scaled =
-      case Nx.shape(factor) do
-        {} -> factor * x
-        _ -> Nx.new_axis(factor, 1) * x
-      end
-
-    {x_scaled, factor * y}
-  end
-
-  defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
-    if fit_intercept? do
-      y_offset - Nx.dot(x_offset, coeff)
-    else
-      Nx.tensor(0.0, type: Nx.type(coeff))
-    end
-  end
-
-  defnp preprocess_data(x, y, sample_weights, opts) do
-    if opts[:sample_weights_flag],
-      do:
-        {Nx.weighted_mean(x, sample_weights, axes: [0]),
-         Nx.weighted_mean(y, sample_weights, axes: [0])},
-      else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
-  end
 end
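Note: the three private helpers removed above (rescale/3, set_intercept/4, preprocess_data/4)
were duplicated across the linear models; this patch moves them into the shared
Scholar.Linear.LinearHelpers module introduced below. One detail visible in the diff: the
default for :sample_weights is no longer the scalar 1.0 but a rank-1 tensor of ones, one
weight per row of x. A minimal sketch of the new helper's behavior (illustrative values,
run against this branch):

    iex> x = Nx.iota({3, 2}, type: :f32)
    iex> Scholar.Linear.LinearHelpers.build_sample_weights(x, [])
    #Nx.Tensor<
      f32[3]
      [1.0, 1.0, 1.0]
    >
    iex> Scholar.Linear.LinearHelpers.build_sample_weights(x, sample_weights: [1, 2, 3])
    #Nx.Tensor<
      f32[3]
      [1.0, 2.0, 3.0]
    >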
diff --git a/lib/scholar/linear/linear_helpers.ex b/lib/scholar/linear/linear_helpers.ex
new file mode 100644
index 00000000..cbe60e53
--- /dev/null
+++ b/lib/scholar/linear/linear_helpers.ex
@@ -0,0 +1,62 @@
+defmodule Scholar.Linear.LinearHelpers do
+  require Nx
+  import Nx.Defn
+  import Scholar.Shared
+
+  @moduledoc false
+
+  @doc false
+  def build_sample_weights(x, opts) do
+    x_type = to_float_type(x)
+    {num_samples, _} = Nx.shape(x)
+    default_sample_weights = Nx.broadcast(Nx.as_type(1.0, x_type), {num_samples})
+    {sample_weights, _} = Keyword.pop(opts, :sample_weights, default_sample_weights)
+
+    # this is required for ridge regression
+    sample_weights =
+      if Nx.is_tensor(sample_weights),
+        do: Nx.as_type(sample_weights, x_type),
+        else: Nx.tensor(sample_weights, type: x_type)
+
+    sample_weights
+  end
+
+  @doc false
+  defn preprocess_data(x, y, sample_weights, opts) do
+    if opts[:sample_weights_flag],
+      do:
+        {Nx.weighted_mean(x, sample_weights, axes: [0]),
+         Nx.weighted_mean(y, sample_weights, axes: [0])},
+      else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
+  end
+
+  @doc false
+  defn set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
+    if fit_intercept? do
+      y_offset - Nx.dot(coeff, x_offset)
+    else
+      Nx.tensor(0.0, type: Nx.type(coeff))
+    end
+  end
+
+  # Implements sample weighting by rescaling inputs and
+  # targets by sqrt(sample_weight).
+  @doc false
+  defn rescale(x, y, sample_weights) do
+    factor = Nx.sqrt(sample_weights)
+
+    x_scaled =
+      case Nx.shape(factor) do
+        {} -> factor * x
+        _ -> x * Nx.new_axis(factor, -1)
+      end
+
+    y_scaled =
+      case Nx.rank(y) do
+        1 -> factor * y
+        _ -> y * Nx.new_axis(factor, -1)
+      end
+
+    {x_scaled, y_scaled}
+  end
+end
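Note on rescale/3: the removed copies in LinearRegression and RidgeRegression materialized
an n x n matrix with Nx.make_diagonal/1 and then used Nx.dot/2; the shared helper instead
broadcasts sqrt(sample_weights) over the rows, avoiding the quadratic intermediate, and it
now also handles rank-2 targets. A minimal check of the broadcasting path (toy values
chosen so the square roots are exact):

    iex> x = Nx.iota({2, 2}, type: :f32)
    iex> y = Nx.tensor([1.0, 1.0])
    iex> w = Nx.tensor([4.0, 9.0])
    iex> {x_scaled, y_scaled} = Scholar.Linear.LinearHelpers.rescale(x, y, w)
    iex> x_scaled
    #Nx.Tensor<
      f32[2][2]
      [
        [0.0, 2.0],
        [6.0, 9.0]
      ]
    >
    iex> y_scaled
    #Nx.Tensor<
      f32[2]
      [2.0, 3.0]
    >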
diff --git a/lib/scholar/linear/linear_regression.ex b/lib/scholar/linear/linear_regression.ex
index e5d96512..59885ba3 100644
--- a/lib/scholar/linear/linear_regression.ex
+++ b/lib/scholar/linear/linear_regression.ex
@@ -8,6 +8,7 @@ defmodule Scholar.Linear.LinearRegression do
   require Nx
   import Nx.Defn
   import Scholar.Shared
+  alias Scholar.Linear.LinearHelpers
 
   @derive {Nx.Container, containers: [:coefficients, :intercept]}
   defstruct [:coefficients, :intercept]
@@ -75,13 +76,7 @@ defmodule Scholar.Linear.LinearRegression do
      ] ++
        opts
 
-    {sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
-    x_type = to_float_type(x)
-
-    sample_weights =
-      if Nx.is_tensor(sample_weights),
-        do: Nx.as_type(sample_weights, x_type),
-        else: Nx.tensor(sample_weights, type: x_type)
+    sample_weights = LinearHelpers.build_sample_weights(x, opts)
 
    fit_n(x, y, sample_weights, opts)
  end
@@ -92,7 +87,7 @@ defmodule Scholar.Linear.LinearRegression do
 
    {a_offset, b_offset} =
      if opts[:fit_intercept?] do
-        preprocess_data(a, b, sample_weights, opts)
+        LinearHelpers.preprocess_data(a, b, sample_weights, opts)
      else
        a_offset_shape = Nx.axis_size(a, 1)
        b_reshaped = if Nx.rank(b) > 1, do: b, else: Nx.reshape(b, {:auto, 1})
@@ -106,7 +101,7 @@ defmodule Scholar.Linear.LinearRegression do
 
    {a, b} =
      if opts[:sample_weights_flag] do
-        rescale(a, b, sample_weights)
+        LinearHelpers.rescale(a, b, sample_weights)
      else
        {a, b}
      end
@@ -132,42 +127,12 @@ defmodule Scholar.Linear.LinearRegression do
    Nx.dot(x, coeff) + intercept
  end
 
-  # Implements sample weighting by rescaling inputs and
-  # targets by sqrt(sample_weight).
-  defnp rescale(x, y, sample_weights) do
-    case Nx.shape(sample_weights) do
-      {} = scalar ->
-        scalar = Nx.sqrt(scalar)
-        {scalar * x, scalar * y}
-
-      _ ->
-        scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
-        {Nx.dot(scale, x), Nx.dot(scale, y)}
-    end
-  end
-
  # Implements ordinary least-squares by estimating the
  # solution A to the equation A.X = b.
  defnp lstsq(a, b, a_offset, b_offset, fit_intercept?) do
    pinv = Nx.LinAlg.pinv(a)
    coeff = Nx.dot(b, [0], pinv, [1])
-    intercept = set_intercept(coeff, a_offset, b_offset, fit_intercept?)
+    intercept = LinearHelpers.set_intercept(coeff, a_offset, b_offset, fit_intercept?)
    {coeff, intercept}
  end
-
-  defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
-    if fit_intercept? do
-      y_offset - Nx.dot(coeff, x_offset)
-    else
-      Nx.tensor(0.0, type: Nx.type(coeff))
-    end
-  end
-
-  defnp preprocess_data(x, y, sample_weights, opts) do
-    if opts[:sample_weights_flag],
-      do:
-        {Nx.weighted_mean(x, sample_weights, axes: [0]),
-         Nx.weighted_mean(y, sample_weights, axes: [0])},
-      else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
-  end
 end
diff --git a/lib/scholar/linear/polynomial_regression.ex b/lib/scholar/linear/polynomial_regression.ex
index e15316a8..76246ad2 100644
--- a/lib/scholar/linear/polynomial_regression.ex
+++ b/lib/scholar/linear/polynomial_regression.ex
@@ -12,7 +12,7 @@ defmodule Scholar.Linear.PolynomialRegression do
 
   opts = [
     sample_weights: [
-      type: {:list, {:custom, Scholar.Options, :positive_number, []}},
+      type: {:custom, Scholar.Options, :weights, []},
       doc: """
       The weights for each observation. If not provided,
       all observations are assigned equal weight.
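Note: PolynomialRegression only needs the option-type change above, since it delegates the
actual fitting to LinearRegression; with the shared :weights validator, all of the models
touched here should accept the same weight forms. A hedged usage sketch (toy data; assumes
the :weights validator accepts both lists and rank-1 tensors, as the docstrings suggest):

    iex> x = Nx.tensor([[1.0], [2.0], [3.0], [4.0]])
    iex> y = Nx.tensor([2.1, 3.9, 6.2, 7.8])
    iex> w = Nx.tensor([1.0, 2.0, 2.0, 1.0])
    iex> model = Scholar.Linear.PolynomialRegression.fit(x, y, degree: 2, sample_weights: w)
    iex> Scholar.Linear.PolynomialRegression.predict(model, Nx.tensor([[2.5]]))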
diff --git a/lib/scholar/linear/ridge_regression.ex b/lib/scholar/linear/ridge_regression.ex
index f63f9893..70bed1c6 100644
--- a/lib/scholar/linear/ridge_regression.ex
+++ b/lib/scholar/linear/ridge_regression.ex
@@ -22,19 +22,14 @@ defmodule Scholar.Linear.RidgeRegression do
   require Nx
   import Nx.Defn
   import Scholar.Shared
+  alias Scholar.Linear.LinearHelpers
 
   @derive {Nx.Container, containers: [:coefficients, :intercept]}
   defstruct [:coefficients, :intercept]
 
   opts = [
     sample_weights: [
-      type:
-        {:or,
-         [
-           {:custom, Scholar.Options, :non_negative_number, []},
-           {:list, {:custom, Scholar.Options, :non_negative_number, []}},
-           {:custom, Scholar.Options, :weights, []}
-         ]},
+      type: {:custom, Scholar.Options, :weights, []},
       doc: """
       The weights for each observation. If not provided,
       all observations are assigned equal weight.
@@ -126,13 +121,9 @@ defmodule Scholar.Linear.RidgeRegression do
      ] ++
        opts
 
-    {sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
    x_type = to_float_type(x)
 
-    sample_weights =
-      if Nx.is_tensor(sample_weights),
-        do: Nx.as_type(sample_weights, x_type),
-        else: Nx.tensor(sample_weights, type: x_type)
+    sample_weights = LinearHelpers.build_sample_weights(x, opts)
 
    {alpha, opts} = Keyword.pop!(opts, :alpha)
    alpha = Nx.tensor(alpha, type: x_type) |> Nx.flatten()
@@ -160,7 +151,7 @@ defmodule Scholar.Linear.RidgeRegression do
 
    {a_offset, b_offset} =
      if opts[:fit_intercept?] do
-        preprocess_data(a, b, sample_weights, opts)
+        LinearHelpers.preprocess_data(a, b, sample_weights, opts)
      else
        a_offset_shape = Nx.axis_size(a, 1)
        b_reshaped = if Nx.rank(b) > 1, do: b, else: Nx.reshape(b, {:auto, 1})
@@ -175,7 +166,7 @@ defmodule Scholar.Linear.RidgeRegression do
 
    {a, b} =
      if opts[:rescale_flag] do
-        rescale(a, b, sample_weights)
+        LinearHelpers.rescale(a, b, sample_weights)
      else
        {a, b}
      end
@@ -198,7 +189,7 @@ defmodule Scholar.Linear.RidgeRegression do
      end
 
    coeff = if flatten?, do: Nx.flatten(coeff), else: coeff
-    intercept = set_intercept(coeff, a_offset, b_offset, opts[:fit_intercept?])
+    intercept = LinearHelpers.set_intercept(coeff, a_offset, b_offset, opts[:fit_intercept?])
    %__MODULE__{coefficients: coeff, intercept: intercept}
  end
@@ -222,20 +213,6 @@ defmodule Scholar.Linear.RidgeRegression do
    if original_rank <= 1, do: Nx.squeeze(res, axes: [1]), else: res
  end
 
-  # Implements sample weighting by rescaling inputs and
-  # targets by sqrt(sample_weight).
-  defnp rescale(a, b, sample_weights) do
-    case Nx.shape(sample_weights) do
-      {} = scalar ->
-        scalar = Nx.sqrt(scalar)
-        {scalar * a, scalar * b}
-
-      _ ->
-        scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
-        {Nx.dot(scale, a), Nx.dot(scale, b)}
-    end
-  end
-
  defnp solve_cholesky_kernel(kernel, b, alpha, sample_weights, opts) do
    num_samples = Nx.axis_size(kernel, 0)
    num_targets = Nx.axis_size(b, 1)
@@ -325,20 +302,4 @@ defmodule Scholar.Linear.RidgeRegression do
    d_uty = d * uty
    Nx.dot(d_uty, [0], vt, [0])
  end
-
-  defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
-    if fit_intercept? do
-      y_offset - Nx.dot(coeff, x_offset)
-    else
-      Nx.tensor(0.0, type: Nx.type(coeff))
-    end
-  end
-
-  defnp preprocess_data(a, b, sample_weights, opts) do
-    if opts[:sample_weights_flag],
-      do:
-        {Nx.weighted_mean(a, sample_weights, axes: [0]),
-         Nx.weighted_mean(b, sample_weights, axes: [0])},
-      else: {Nx.mean(a, axes: [0]), Nx.mean(b, axes: [0])}
-  end
 end
diff --git a/test/scholar/linear/bayesian_ridge_regression_test.exs b/test/scholar/linear/bayesian_ridge_regression_test.exs
index 7f4e36fc..845bbf74 100644
--- a/test/scholar/linear/bayesian_ridge_regression_test.exs
+++ b/test/scholar/linear/bayesian_ridge_regression_test.exs
@@ -53,7 +53,7 @@ defmodule Scholar.Linear.BayesianRidgeRegressionTest do
    score = compute_score(x, y, alpha, lambda, alpha_1, alpha_2, lambda_1, lambda_2)
 
    brr =
-      BayesianRidgeRegression.fit(x, y,
+      BayesianRidgeRegression.fit(x, Nx.flatten(y),
        alpha_1: alpha_1,
        alpha_2: alpha_2,
        lambda_1: lambda_1,
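Note on the test change: the target is now passed as a rank-1 tensor (Nx.flatten(y)),
matching the target shape this test exercised through the unified helpers. End to end, a
weighted fit now reads the same across the refactored models; a sketch with toy data
(return structs elided):

    iex> x = Nx.tensor([[1.0], [2.0], [3.0]])
    iex> y = Nx.tensor([2.0, 4.0, 6.0])
    iex> w = Nx.tensor([1.0, 2.0, 1.0])
    iex> Scholar.Linear.LinearRegression.fit(x, y, sample_weights: w)
    iex> Scholar.Linear.RidgeRegression.fit(x, y, alpha: 0.1, sample_weights: w)
    iex> Scholar.Linear.BayesianRidgeRegression.fit(x, y, sample_weights: w)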