Unify weight handling and refactor linear models' helper functions (e…
JoaquinIglesiasTurina authored May 16, 2024
1 parent 74ed5fe commit 09d500a
Showing 6 changed files with 81 additions and 132 deletions.
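
All six files converge on a single sample_weights option (validated by Scholar.Options.weights) whose handling now lives in the new Scholar.Linear.LinearHelpers module. A minimal sketch of the unified option, with illustrative data, assuming only that Nx and Scholar are available:

x = Nx.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = Nx.tensor([1.0, 2.0, 3.0])
weights = Nx.tensor([0.5, 1.0, 2.0])  # one non-negative weight per observation

model = Scholar.Linear.LinearRegression.fit(x, y, sample_weights: weights)
Scholar.Linear.LinearRegression.predict(model, Nx.tensor([[2.0, 3.0]]))
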
51 changes: 6 additions & 45 deletions lib/scholar/linear/bayesian_ridge_regression.ex
@@ -62,6 +62,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
require Nx
import Nx.Defn
import Scholar.Shared
alias Scholar.Linear.LinearHelpers

@derive {Nx.Container,
containers: [
@@ -95,13 +96,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
"""
],
sample_weights: [
type:
{:or,
[
{:custom, Scholar.Options, :non_negative_number, []},
{:list, {:custom, Scholar.Options, :non_negative_number, []}},
{:custom, Scholar.Options, :weights, []}
]},
type: {:custom, Scholar.Options, :weights, []},
doc: """
The weights for each observation. If not provided,
all observations are assigned equal weight.
@@ -237,13 +232,9 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
] ++
opts

{sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
x_type = to_float_type(x)

sample_weights =
if Nx.is_tensor(sample_weights),
do: Nx.as_type(sample_weights, x_type),
else: Nx.tensor(sample_weights, type: x_type)
sample_weights = LinearHelpers.build_sample_weights(x, opts)

# handle vector types
# handle default alpha value, add eps to avoid division by 0
@@ -288,7 +279,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do

{x_offset, y_offset} =
if opts[:fit_intercept?] do
preprocess_data(x, y, sample_weights, opts)
LinearHelpers.preprocess_data(x, y, sample_weights, opts)
else
x_offset_shape = Nx.axis_size(x, 1)
y_reshaped = if Nx.rank(y) > 1, do: y, else: Nx.reshape(y, {:auto, 1})
@@ -302,7 +293,7 @@

{x, y} =
if opts[:sample_weights_flag] do
rescale(x, y, sample_weights)
LinearHelpers.rescale(x, y, sample_weights)
else
{x, y}
end
@@ -360,7 +351,7 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
{x, y, xt_y, u, s, vh, eigenvals, alpha_1, alpha_2, lambda_1, lambda_2, iterations}}
end

intercept = set_intercept(coef, x_offset, y_offset, opts[:fit_intercept?])
intercept = LinearHelpers.set_intercept(coef, x_offset, y_offset, opts[:fit_intercept?])
scaled_sigma = Nx.dot(vh, [0], vh / Nx.new_axis(eigenvals + lambda / alpha, -1), [0])
sigma = scaled_sigma / alpha
{coef, intercept, alpha, lambda, iter, has_converged, scores, sigma}
@@ -449,34 +440,4 @@ defmodule Scholar.Linear.BayesianRidgeRegression do
end

defnp predict_n(coeff, intercept, x), do: Nx.dot(x, [-1], coeff, [-1]) + intercept

# Implements sample weighting by rescaling inputs and
# targets by sqrt(sample_weight).
defnp rescale(x, y, sample_weights) do
factor = Nx.sqrt(sample_weights)

x_scaled =
case Nx.shape(factor) do
{} -> factor * x
_ -> Nx.new_axis(factor, 1) * x
end

{x_scaled, factor * y}
end

defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
if fit_intercept? do
y_offset - Nx.dot(x_offset, coeff)
else
Nx.tensor(0.0, type: Nx.type(coeff))
end
end

defnp preprocess_data(x, y, sample_weights, opts) do
if opts[:sample_weights_flag],
do:
{Nx.weighted_mean(x, sample_weights, axes: [0]),
Nx.weighted_mean(y, sample_weights, axes: [0])},
else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
end
end
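
The three helpers removed here reappear in Scholar.Linear.LinearHelpers below. For reference, set_intercept pins the intercept so predictions pass through the (weighted) centroid of the training data: intercept = y_offset - coeff · x_offset when fit_intercept? is set, else 0. A standalone sketch with illustrative numbers:

coeff = Nx.tensor([2.0])
x_offset = Nx.tensor([1.5])   # column mean of x
y_offset = Nx.tensor(4.0)     # mean of y
Nx.subtract(y_offset, Nx.dot(coeff, x_offset))
#=> 1.0, so the fitted line 2.0 * x + 1.0 passes through {1.5, 4.0}
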
62 changes: 62 additions & 0 deletions lib/scholar/linear/linear_helpers.ex
@@ -0,0 +1,62 @@
defmodule Scholar.Linear.LinearHelpers do
require Nx
import Nx.Defn
import Scholar.Shared

@moduledoc false

@doc false
def build_sample_weights(x, opts) do
x_type = to_float_type(x)
{num_samples, _} = Nx.shape(x)
default_sample_weights = Nx.broadcast(Nx.as_type(1.0, x_type), {num_samples})
{sample_weights, _} = Keyword.pop(opts, :sample_weights, default_sample_weights)

# this is required for ridge regression
sample_weights =
if Nx.is_tensor(sample_weights),
do: Nx.as_type(sample_weights, x_type),
else: Nx.tensor(sample_weights, type: x_type)

sample_weights
end

@doc false
defn preprocess_data(x, y, sample_weights, opts) do
if opts[:sample_weights_flag],
do:
{Nx.weighted_mean(x, sample_weights, axes: [0]),
Nx.weighted_mean(y, sample_weights, axes: [0])},
else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
end

@doc false
defn set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
if fit_intercept? do
y_offset - Nx.dot(coeff, x_offset)
else
Nx.tensor(0.0, type: Nx.type(coeff))
end
end

# Implements sample weighting by rescaling inputs and
# targets by sqrt(sample_weight).
@doc false
defn rescale(x, y, sample_weights) do
factor = Nx.sqrt(sample_weights)

x_scaled =
case Nx.shape(factor) do
{} -> factor * x
_ -> x * Nx.new_axis(factor, -1)
end

y_scaled =
case Nx.rank(y) do
1 -> factor * y
_ -> y * Nx.new_axis(factor, -1)
end

{x_scaled, y_scaled}
end
end
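
rescale/3 uses the standard identity that minimizing the weighted residual sum Σ wᵢ (yᵢ - xᵢ·β)² is ordinary least squares on rows scaled by √wᵢ, which is why scaling both x and y by Nx.sqrt(sample_weights) is sufficient. A sketch of the same scaling outside defn (Nx functions instead of operators; values illustrative):

w = Nx.tensor([1.0, 4.0])
x = Nx.tensor([[1.0], [2.0]])
y = Nx.tensor([1.0, 3.0])
factor = Nx.sqrt(w)
x_scaled = Nx.multiply(x, Nx.new_axis(factor, -1))  # each row of x scaled by √wᵢ
y_scaled = Nx.multiply(y, factor)
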
45 changes: 5 additions & 40 deletions lib/scholar/linear/linear_regression.ex
@@ -8,6 +8,7 @@ defmodule Scholar.Linear.LinearRegression do
require Nx
import Nx.Defn
import Scholar.Shared
alias Scholar.Linear.LinearHelpers

@derive {Nx.Container, containers: [:coefficients, :intercept]}
defstruct [:coefficients, :intercept]
@@ -75,13 +76,7 @@ defmodule Scholar.Linear.LinearRegression do
] ++
opts

{sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
x_type = to_float_type(x)

sample_weights =
if Nx.is_tensor(sample_weights),
do: Nx.as_type(sample_weights, x_type),
else: Nx.tensor(sample_weights, type: x_type)
sample_weights = LinearHelpers.build_sample_weights(x, opts)

fit_n(x, y, sample_weights, opts)
end
@@ -92,7 +87,7 @@

{a_offset, b_offset} =
if opts[:fit_intercept?] do
preprocess_data(a, b, sample_weights, opts)
LinearHelpers.preprocess_data(a, b, sample_weights, opts)
else
a_offset_shape = Nx.axis_size(a, 1)
b_reshaped = if Nx.rank(b) > 1, do: b, else: Nx.reshape(b, {:auto, 1})
@@ -106,7 +101,7 @@

{a, b} =
if opts[:sample_weights_flag] do
rescale(a, b, sample_weights)
LinearHelpers.rescale(a, b, sample_weights)
else
{a, b}
end
@@ -132,42 +127,12 @@
Nx.dot(x, coeff) + intercept
end

# Implements sample weighting by rescaling inputs and
# targets by sqrt(sample_weight).
defnp rescale(x, y, sample_weights) do
case Nx.shape(sample_weights) do
{} = scalar ->
scalar = Nx.sqrt(scalar)
{scalar * x, scalar * y}

_ ->
scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
{Nx.dot(scale, x), Nx.dot(scale, y)}
end
end

# Implements ordinary least-squares by estimating the
# solution A to the equation A.X = b.
defnp lstsq(a, b, a_offset, b_offset, fit_intercept?) do
pinv = Nx.LinAlg.pinv(a)
coeff = Nx.dot(b, [0], pinv, [1])
intercept = set_intercept(coeff, a_offset, b_offset, fit_intercept?)
intercept = LinearHelpers.set_intercept(coeff, a_offset, b_offset, fit_intercept?)
{coeff, intercept}
end

defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
if fit_intercept? do
y_offset - Nx.dot(coeff, x_offset)
else
Nx.tensor(0.0, type: Nx.type(coeff))
end
end

defnp preprocess_data(x, y, sample_weights, opts) do
if opts[:sample_weights_flag],
do:
{Nx.weighted_mean(x, sample_weights, axes: [0]),
Nx.weighted_mean(y, sample_weights, axes: [0])},
else: {Nx.mean(x, axes: [0]), Nx.mean(y, axes: [0])}
end
end
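
lstsq/5 above computes the least-squares solution through the Moore-Penrose pseudoinverse rather than forming normal equations. The same computation standalone, with illustrative data:

a = Nx.tensor([[1.0, 1.0], [1.0, 2.0], [1.0, 3.0]])  # first column acts as intercept
b = Nx.tensor([2.0, 3.0, 4.0])
pinv = Nx.LinAlg.pinv(a)
Nx.dot(b, [0], pinv, [1])
#=> approximately [1.0, 1.0], recovering intercept and slope exactly here
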
2 changes: 1 addition & 1 deletion lib/scholar/linear/polynomial_regression.ex
@@ -12,7 +12,7 @@ defmodule Scholar.Linear.PolynomialRegression do

opts = [
sample_weights: [
type: {:list, {:custom, Scholar.Options, :positive_number, []}},
type: {:custom, Scholar.Options, :weights, []},
doc: """
The weights for each observation. If not provided,
all observations are assigned equal weight.
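
PolynomialRegression expands the input features and then fits a linear regression on them, so the same unified weights option applies. A hedged sketch with illustrative data:

x = Nx.tensor([[1.0], [2.0], [3.0]])
y = Nx.tensor([1.0, 4.0, 9.0])

Scholar.Linear.PolynomialRegression.fit(x, y,
  degree: 2,
  sample_weights: Nx.tensor([1.0, 1.0, 2.0])
)
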
51 changes: 6 additions & 45 deletions lib/scholar/linear/ridge_regression.ex
@@ -22,19 +22,14 @@ defmodule Scholar.Linear.RidgeRegression do
require Nx
import Nx.Defn
import Scholar.Shared
alias Scholar.Linear.LinearHelpers

@derive {Nx.Container, containers: [:coefficients, :intercept]}
defstruct [:coefficients, :intercept]

opts = [
sample_weights: [
type:
{:or,
[
{:custom, Scholar.Options, :non_negative_number, []},
{:list, {:custom, Scholar.Options, :non_negative_number, []}},
{:custom, Scholar.Options, :weights, []}
]},
type: {:custom, Scholar.Options, :weights, []},
doc: """
The weights for each observation. If not provided,
all observations are assigned equal weight.
@@ -126,13 +121,9 @@ defmodule Scholar.Linear.RidgeRegression do
] ++
opts

{sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0)
x_type = to_float_type(x)

sample_weights =
if Nx.is_tensor(sample_weights),
do: Nx.as_type(sample_weights, x_type),
else: Nx.tensor(sample_weights, type: x_type)
sample_weights = LinearHelpers.build_sample_weights(x, opts)

{alpha, opts} = Keyword.pop!(opts, :alpha)
alpha = Nx.tensor(alpha, type: x_type) |> Nx.flatten()
@@ -160,7 +151,7 @@ defmodule Scholar.Linear.RidgeRegression do

{a_offset, b_offset} =
if opts[:fit_intercept?] do
preprocess_data(a, b, sample_weights, opts)
LinearHelpers.preprocess_data(a, b, sample_weights, opts)
else
a_offset_shape = Nx.axis_size(a, 1)
b_reshaped = if Nx.rank(b) > 1, do: b, else: Nx.reshape(b, {:auto, 1})
@@ -175,7 +166,7 @@

{a, b} =
if opts[:rescale_flag] do
rescale(a, b, sample_weights)
LinearHelpers.rescale(a, b, sample_weights)
else
{a, b}
end
@@ -198,7 +189,7 @@
end

coeff = if flatten?, do: Nx.flatten(coeff), else: coeff
intercept = set_intercept(coeff, a_offset, b_offset, opts[:fit_intercept?])
intercept = LinearHelpers.set_intercept(coeff, a_offset, b_offset, opts[:fit_intercept?])
%__MODULE__{coefficients: coeff, intercept: intercept}
end

@@ -222,20 +213,6 @@
if original_rank <= 1, do: Nx.squeeze(res, axes: [1]), else: res
end

# Implements sample weighting by rescaling inputs and
# targets by sqrt(sample_weight).
defnp rescale(a, b, sample_weights) do
case Nx.shape(sample_weights) do
{} = scalar ->
scalar = Nx.sqrt(scalar)
{scalar * a, scalar * b}

_ ->
scale = sample_weights |> Nx.sqrt() |> Nx.make_diagonal()
{Nx.dot(scale, a), Nx.dot(scale, b)}
end
end

defnp solve_cholesky_kernel(kernel, b, alpha, sample_weights, opts) do
num_samples = Nx.axis_size(kernel, 0)
num_targets = Nx.axis_size(b, 1)
@@ -325,20 +302,4 @@
d_uty = d * uty
Nx.dot(d_uty, [0], vt, [0])
end

defnp set_intercept(coeff, x_offset, y_offset, fit_intercept?) do
if fit_intercept? do
y_offset - Nx.dot(coeff, x_offset)
else
Nx.tensor(0.0, type: Nx.type(coeff))
end
end

defnp preprocess_data(a, b, sample_weights, opts) do
if opts[:sample_weights_flag],
do:
{Nx.weighted_mean(a, sample_weights, axes: [0]),
Nx.weighted_mean(b, sample_weights, axes: [0])},
else: {Nx.mean(a, axes: [0]), Nx.mean(b, axes: [0])}
end
end
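
Both solver paths in this module (Cholesky on the Gram/kernel matrix and the SVD route above) are factorizations of the same closed form β = (AᵀA + αI)⁻¹ Aᵀ b. A direct, unoptimized sketch of that formula with illustrative data, not the module's actual solve path:

a = Nx.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b = Nx.tensor([1.0, 2.0, 3.0])
alpha = 0.5  # ridge penalty
gram = Nx.add(Nx.dot(Nx.transpose(a), a), Nx.multiply(alpha, Nx.eye(2)))
Nx.dot(Nx.LinAlg.solve(gram, Nx.transpose(a)), b)
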
2 changes: 1 addition & 1 deletion test/scholar/linear/bayesian_ridge_regression_test.exs
@@ -53,7 +53,7 @@ defmodule Scholar.Linear.BayesianRidgeRegressionTest do
score = compute_score(x, y, alpha, lambda, alpha_1, alpha_2, lambda_1, lambda_2)

brr =
BayesianRidgeRegression.fit(x, y,
BayesianRidgeRegression.fit(x, Nx.flatten(y),
alpha_1: alpha_1,
alpha_2: alpha_2,
lambda_1: lambda_1,
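
The Nx.flatten/1 call converts a column target of shape {n, 1} into the rank-1 target this test now passes to fit, e.g.:

Nx.flatten(Nx.tensor([[1.0], [2.0], [3.0]]))
#=> f32[3] tensor [1.0, 2.0, 3.0]
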
