Skip to content

Commit

Permalink
Improve linear interpolation (elixir-nx#190)
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak authored Oct 18, 2023
1 parent b36df2f commit dd87b99
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 14 deletions.
133 changes: 122 additions & 11 deletions lib/scholar/interpolation/linear.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,27 @@ defmodule Scholar.Interpolation.Linear do
$$
"""
import Nx.Defn
import Scholar.Shared

@derive {Nx.Container, containers: [:coefficients, :x]}
defstruct [:coefficients, :x]

@type t :: %Scholar.Interpolation.Linear{}

opts_schema = [
left: [
type: {:or, [:float, :integer]},
doc:
"Value to return for values in `target_x` smaller that the smallest value in training set"
],
right: [
type: {:or, [:float, :integer]},
doc:
"Value to return for values in `target_x` greater that the greatest value in training set"
]
]

@opts_schema NimbleOptions.new!(opts_schema)
@doc """
Fits a linear interpolation of the given `(x, y)` points
Expand All @@ -41,7 +56,7 @@ defmodule Scholar.Interpolation.Linear do
]
),
x: Nx.tensor(
[0, 1]
[0, 1, 2]
)
}
"""
Expand Down Expand Up @@ -79,7 +94,7 @@ defmodule Scholar.Interpolation.Linear do

coefficients = Nx.stack([a, b], axis: 1)

%__MODULE__{coefficients: coefficients, x: x0}
%__MODULE__{coefficients: coefficients, x: x}
end

@doc """
Expand All @@ -97,24 +112,120 @@ defmodule Scholar.Interpolation.Linear do
[2.0, 6.0]
]
)
iex> x = Nx.iota({5})
iex> y = Nx.tensor([2.0, 0.0, 1.0, 3.0, 4.0])
iex> model = Scholar.Interpolation.Linear.fit(x, y)
iex> target_x = Nx.tensor([-2, -1, 1.25, 3, 3.25, 5.0])
iex> Scholar.Interpolation.Linear.predict(model, target_x, left: 0.0, right: 10.0)
#Nx.Tensor<
f32[6]
[0.0, 0.0, 0.25, 3.0, 3.25, 10.0]
>
"""
defn predict(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x) do
original_shape = Nx.shape(target_x)
deftransform predict(model, target_x, opts \\ []) do
predict_n(model, target_x, NimbleOptions.validate!(opts, @opts_schema))
end

defnp predict_n(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x, opts) do
shape = Nx.shape(target_x)

target_x = Nx.flatten(target_x)

idx_selector = Nx.new_axis(target_x, 1) >= x
indices = Nx.argsort(target_x)

left_bound = x[0]
right_bound = x[-1]

target_x = Nx.sort(target_x)
res = Nx.broadcast(Nx.tensor(0, type: to_float_type(target_x)), {Nx.axis_size(target_x, 0)})

# while with smaller than left_bound
{{res, i}, _} =
while {{res, i = 0}, {x, coefficients, left_bound, target_x}},
check_cond_left(target_x, i, left_bound) do
val =
case opts[:left] do
nil ->
coefficients[0][0] * Nx.take(target_x, i) + coefficients[0][1]

_ ->
opts[:left]
end

res = Nx.indexed_put(res, Nx.new_axis(i, -1), val)
{{res, i + 1}, {x, coefficients, left_bound, target_x}}
end

{{res, i}, _} =
while {{res, i}, {x, right_bound, coefficients, target_x, j = 0}},
check_cond_right(target_x, i, right_bound) do
{j, _} =
while {j, {i, x, target_x}},
j < Nx.axis_size(x, 0) and Nx.take(x, j) < Nx.take(target_x, i) do
{j + 1, {i, x, target_x}}
end

res =
Nx.indexed_put(
res,
Nx.new_axis(i, -1),
coefficients[Nx.max(j - 1, 0)][0] * Nx.take(target_x, i) +
coefficients[Nx.max(j - 1, 0)][1]
)

i = i + 1

{{res, i}, {x, right_bound, coefficients, target_x, j}}
end

{res, i}

# while with bigger than right_bound

{res, _} =
while {res, {x, coefficients, target_x, i}},
i < Nx.axis_size(target_x, 0) do
val =
case opts[:right] do
nil ->
coefficients[-1][0] * Nx.take(target_x, i) + coefficients[-1][1]

_ ->
opts[:right]
end

res = Nx.indexed_put(res, Nx.new_axis(i, -1), val)
{res, {x, coefficients, target_x, i + 1}}
end

res = Nx.take(res, indices)
Nx.reshape(res, shape)
end

idx_poly = Nx.argmax(idx_selector, axis: 1, tie_break: :high)
defnp check_cond_left(target_x, i, left_bound) do
cond do
i >= Nx.axis_size(target_x, 0) ->
Nx.u8(0)

idx_poly = Nx.select(Nx.all(idx_selector == 0, axes: [1]), 0, idx_poly)
Nx.take(target_x, i) >= left_bound ->
Nx.u8(0)

coef_poly = Nx.take(coefficients, idx_poly)
true ->
Nx.u8(1)
end
end

x_poly = Nx.stack([target_x, Nx.broadcast(1, target_x)], axis: 1)
defnp check_cond_right(target_x, i, right_bound) do
cond do
i >= Nx.axis_size(target_x, 0) ->
Nx.u8(0)

result = Nx.dot(x_poly, [1], [0], coef_poly, [1], [0])
Nx.take(target_x, i) > right_bound ->
Nx.u8(0)

Nx.reshape(result, original_shape)
true ->
Nx.u8(1)
end
end
end
4 changes: 2 additions & 2 deletions lib/scholar/linear/isotonic_regression.ex
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ defmodule Scholar.Linear.IsotonicRegression do
iex> Scholar.Linear.IsotonicRegression.predict(model, to_predict)
#Nx.Tensor<
f32[10]
[1.0, 1.6666667461395264, 2.3333334922790527, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
[1.0, 1.6666667461395264, 2.3333332538604736, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
>
"""
defn predict(model, x) do
Expand Down Expand Up @@ -259,7 +259,7 @@ defmodule Scholar.Linear.IsotonicRegression do
]
),
x: Nx.tensor(
[1.0, 4.0, 7.0, 9.0, 10.0]
[1.0, 4.0, 7.0, 9.0, 10.0, 11.0]
)
}
}
Expand Down
2 changes: 1 addition & 1 deletion test/scholar/linear/isotonic_regression_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ defmodule Scholar.Linear.IsotonicRegressionTest do
[43.0, 44.0, 46.0, 47.0, 55.0, 56.0] ++
[60.0, 61.0, 66.0, 67.0, 71.0, 72.0] ++
[73.0, 76.0, 77.0, 80.0, 81.0, 87.0] ++
[88.0, 92.0, 93.0, 97.0, 98.0]
[88.0, 92.0, 93.0, 97.0, 98.0, 99.0]
)
)
end
Expand Down

0 comments on commit dd87b99

Please sign in to comment.