Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve linear interpolation #190

Merged
merged 23 commits into from
Oct 18, 2023
Merged
Changes from 9 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
77ac659
Update mix.installs
msluszniak Aug 22, 2023
47c288b
Merge branch 'elixir-nx:main' into main
msluszniak Aug 23, 2023
4be4733
Merge branch 'elixir-nx:main' into main
msluszniak Aug 29, 2023
65f7376
Merge branch 'main' of github.com:msluszniak/scholar into main
msluszniak Aug 30, 2023
dd1a29c
Merge branch 'main' of github.com:msluszniak/scholar into main
msluszniak Sep 11, 2023
ba4b000
Merge branch 'main' of github.com:msluszniak/scholar into main
msluszniak Oct 3, 2023
8ca6a32
F-beta score
0urobor0s Oct 10, 2023
8088028
Show possibility of non tensor values as beta
0urobor0s Oct 10, 2023
0b334c8
Make auxiliary functions private
0urobor0s Oct 10, 2023
454d760
Merge commit 'refs/pull/185/head' of github.com:elixir-nx/scholar int…
msluszniak Oct 10, 2023
f6a352a
Add improvements
msluszniak Oct 12, 2023
ef7831b
Bring back type
msluszniak Oct 12, 2023
27b6e96
Merge branch 'main' into improve_linear_interpolation
msluszniak Oct 12, 2023
cff206b
Update lib/scholar/interpolation/linear.ex
josevalim Oct 13, 2023
03f4707
Correct tests
msluszniak Oct 17, 2023
384d415
Merge branch 'improve_linear_interpolation' of github.com:msluszniak/…
msluszniak Oct 17, 2023
bb41862
Change resultant tensor shape
msluszniak Oct 17, 2023
0330d8d
fix
msluszniak Oct 18, 2023
c82d8a1
Add improvements
msluszniak Oct 12, 2023
2faba84
Bring back type
msluszniak Oct 12, 2023
24859a0
Correct tests
msluszniak Oct 17, 2023
f2bb1c6
Change resultant tensor shape
msluszniak Oct 17, 2023
ca0304b
rebase
msluszniak Oct 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 123 additions & 11 deletions lib/scholar/interpolation/linear.ex
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,28 @@ defmodule Scholar.Interpolation.Linear do
$$
"""
import Nx.Defn
import Scholar.Shared

@derive {Nx.Container, containers: [:coefficients, :x]}
defstruct [:coefficients, :x]

@type t :: %Scholar.Interpolation.Linear{}

opts_schema = [
left: [
type: {:or, [:float, :integer]},
doc:
"Value to return for values in `target_x` smaller that the smallest value in training set"
],
right: [
type: {:or, [:float, :integer]},
doc:
"Value to return for values in `target_x` greater that the greatest value in training set"
]
]

@opts_schema NimbleOptions.new!(opts_schema)

josevalim marked this conversation as resolved.
Show resolved Hide resolved
@doc """
Fits a linear interpolation of the given `(x, y)` points

Expand All @@ -41,7 +57,7 @@ defmodule Scholar.Interpolation.Linear do
]
),
x: Nx.tensor(
[0, 1]
[0, 1, 2]
)
}
"""
Expand Down Expand Up @@ -79,7 +95,7 @@ defmodule Scholar.Interpolation.Linear do

coefficients = Nx.stack([a, b], axis: 1)

%__MODULE__{coefficients: coefficients, x: x0}
%__MODULE__{coefficients: coefficients, x: x}
end

@doc """
Expand All @@ -97,24 +113,120 @@ defmodule Scholar.Interpolation.Linear do
[2.0, 6.0]
]
)

iex> x = Nx.iota({5})
iex> y = Nx.tensor([2.0, 0.0, 1.0, 3.0, 4.0])
iex> model = Scholar.Interpolation.Linear.fit(x, y)
iex> target_x = Nx.tensor([-2, -1, 1.25, 3, 3.25, 5.0])
iex> Scholar.Interpolation.Linear.predict(model, target_x, left: 0.0, right: 10.0)
#Nx.Tensor<
f32[6]
[0.0, 0.0, 0.25, 3.0, 3.25, 10.0]
>
"""
defn predict(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x) do
original_shape = Nx.shape(target_x)
deftransform predict(model, target_x, opts \\ []) do
predict_n(model, target_x, NimbleOptions.validate!(opts, @opts_schema))
end

defnp predict_n(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x, opts) do
shape = Nx.shape(target_x)

target_x = Nx.flatten(target_x)

idx_selector = Nx.new_axis(target_x, 1) >= x
indices = Nx.argsort(target_x)

left_bound = x[0]
right_bound = x[-1]

target_x = Nx.sort(target_x)
res = Nx.broadcast(Nx.tensor(0, type: to_float_type(target_x)), {Nx.axis_size(target_x, 0)})

# while with smaller than left_bound
{{res, i}, _} =
while {{res, i = 0}, {x, coefficients, left_bound, target_x}},
check_cond_left(target_x, i, left_bound) do
val =
case opts[:left] do
nil ->
coefficients[0][0] * Nx.take(target_x, i) + coefficients[0][1]

_ ->
opts[:left]
end

res = Nx.indexed_put(res, Nx.new_axis(i, -1), val)
{{res, i + 1}, {x, coefficients, left_bound, target_x}}
end

{{res, i}, _} =
while {{res, i}, {x, right_bound, coefficients, target_x, j = 0}},
check_cond_right(target_x, i, right_bound) do
{j, _} =
while {j, {i, x, target_x}},
j < Nx.axis_size(x, 0) and Nx.take(x, j) < Nx.take(target_x, i) do
{j + 1, {i, x, target_x}}
end

res =
Nx.indexed_put(
res,
Nx.new_axis(i, -1),
coefficients[Nx.max(j - 1, 0)][0] * Nx.take(target_x, i) +
coefficients[Nx.max(j - 1, 0)][1]
)

i = i + 1

{{res, i}, {x, right_bound, coefficients, target_x, j}}
end

{res, i}

# while with bigger than right_bound

{res, _} =
while {res, {x, coefficients, target_x, i}},
i < Nx.axis_size(target_x, 0) do
val =
case opts[:right] do
nil ->
coefficients[-1][0] * Nx.take(target_x, i) + coefficients[-1][1]

_ ->
opts[:right]
end

res = Nx.indexed_put(res, Nx.new_axis(i, -1), val)
{res, {x, coefficients, target_x, i + 1}}
end

res = Nx.take(res, indices)
Nx.reshape(res, shape)
end

idx_poly = Nx.argmax(idx_selector, axis: 1, tie_break: :high)
defnp check_cond_left(target_x, i, left_bound) do
cond do
i >= Nx.axis_size(target_x, 0) ->
Nx.u8(0)

idx_poly = Nx.select(Nx.all(idx_selector == 0, axes: [1]), 0, idx_poly)
Nx.take(target_x, i) >= left_bound ->
Nx.u8(0)

coef_poly = Nx.take(coefficients, idx_poly)
true ->
Nx.u8(1)
end
end

x_poly = Nx.stack([target_x, Nx.broadcast(1, target_x)], axis: 1)
defnp check_cond_right(target_x, i, right_bound) do
cond do
i >= Nx.axis_size(target_x, 0) ->
Nx.u8(0)

result = Nx.dot(x_poly, [1], [0], coef_poly, [1], [0])
Nx.take(target_x, i) > right_bound ->
Nx.u8(0)

Nx.reshape(result, original_shape)
true ->
Nx.u8(1)
end
end
end