diff --git a/lib/scholar/interpolation/linear.ex b/lib/scholar/interpolation/linear.ex index e6ae850f..297627b9 100644 --- a/lib/scholar/interpolation/linear.ex +++ b/lib/scholar/interpolation/linear.ex @@ -16,12 +16,27 @@ defmodule Scholar.Interpolation.Linear do $$ """ import Nx.Defn + import Scholar.Shared @derive {Nx.Container, containers: [:coefficients, :x]} defstruct [:coefficients, :x] @type t :: %Scholar.Interpolation.Linear{} + opts_schema = [ + left: [ + type: {:or, [:float, :integer]}, + doc: + "Value to return for values in `target_x` smaller that the smallest value in training set" + ], + right: [ + type: {:or, [:float, :integer]}, + doc: + "Value to return for values in `target_x` greater that the greatest value in training set" + ] + ] + + @opts_schema NimbleOptions.new!(opts_schema) @doc """ Fits a linear interpolation of the given `(x, y)` points @@ -41,7 +56,7 @@ defmodule Scholar.Interpolation.Linear do ] ), x: Nx.tensor( - [0, 1] + [0, 1, 2] ) } """ @@ -79,7 +94,7 @@ defmodule Scholar.Interpolation.Linear do coefficients = Nx.stack([a, b], axis: 1) - %__MODULE__{coefficients: coefficients, x: x0} + %__MODULE__{coefficients: coefficients, x: x} end @doc """ @@ -97,24 +112,120 @@ defmodule Scholar.Interpolation.Linear do [2.0, 6.0] ] ) + + iex> x = Nx.iota({5}) + iex> y = Nx.tensor([2.0, 0.0, 1.0, 3.0, 4.0]) + iex> model = Scholar.Interpolation.Linear.fit(x, y) + iex> target_x = Nx.tensor([-2, -1, 1.25, 3, 3.25, 5.0]) + iex> Scholar.Interpolation.Linear.predict(model, target_x, left: 0.0, right: 10.0) + #Nx.Tensor< + f32[6] + [0.0, 0.0, 0.25, 3.0, 3.25, 10.0] + > """ - defn predict(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x) do - original_shape = Nx.shape(target_x) + deftransform predict(model, target_x, opts \\ []) do + predict_n(model, target_x, NimbleOptions.validate!(opts, @opts_schema)) + end + + defnp predict_n(%__MODULE__{x: x, coefficients: coefficients} = _model, target_x, opts) do + shape = Nx.shape(target_x) target_x = Nx.flatten(target_x) - idx_selector = Nx.new_axis(target_x, 1) >= x + indices = Nx.argsort(target_x) + + left_bound = x[0] + right_bound = x[-1] + + target_x = Nx.sort(target_x) + res = Nx.broadcast(Nx.tensor(0, type: to_float_type(target_x)), {Nx.axis_size(target_x, 0)}) + + # while with smaller than left_bound + {{res, i}, _} = + while {{res, i = 0}, {x, coefficients, left_bound, target_x}}, + check_cond_left(target_x, i, left_bound) do + val = + case opts[:left] do + nil -> + coefficients[0][0] * Nx.take(target_x, i) + coefficients[0][1] + + _ -> + opts[:left] + end + + res = Nx.indexed_put(res, Nx.new_axis(i, -1), val) + {{res, i + 1}, {x, coefficients, left_bound, target_x}} + end + + {{res, i}, _} = + while {{res, i}, {x, right_bound, coefficients, target_x, j = 0}}, + check_cond_right(target_x, i, right_bound) do + {j, _} = + while {j, {i, x, target_x}}, + j < Nx.axis_size(x, 0) and Nx.take(x, j) < Nx.take(target_x, i) do + {j + 1, {i, x, target_x}} + end + + res = + Nx.indexed_put( + res, + Nx.new_axis(i, -1), + coefficients[Nx.max(j - 1, 0)][0] * Nx.take(target_x, i) + + coefficients[Nx.max(j - 1, 0)][1] + ) + + i = i + 1 + + {{res, i}, {x, right_bound, coefficients, target_x, j}} + end + + {res, i} + + # while with bigger than right_bound + + {res, _} = + while {res, {x, coefficients, target_x, i}}, + i < Nx.axis_size(target_x, 0) do + val = + case opts[:right] do + nil -> + coefficients[-1][0] * Nx.take(target_x, i) + coefficients[-1][1] + + _ -> + opts[:right] + end + + res = Nx.indexed_put(res, Nx.new_axis(i, -1), val) + {res, {x, coefficients, target_x, i + 1}} + end + + res = Nx.take(res, indices) + Nx.reshape(res, shape) + end - idx_poly = Nx.argmax(idx_selector, axis: 1, tie_break: :high) + defnp check_cond_left(target_x, i, left_bound) do + cond do + i >= Nx.axis_size(target_x, 0) -> + Nx.u8(0) - idx_poly = Nx.select(Nx.all(idx_selector == 0, axes: [1]), 0, idx_poly) + Nx.take(target_x, i) >= left_bound -> + Nx.u8(0) - coef_poly = Nx.take(coefficients, idx_poly) + true -> + Nx.u8(1) + end + end - x_poly = Nx.stack([target_x, Nx.broadcast(1, target_x)], axis: 1) + defnp check_cond_right(target_x, i, right_bound) do + cond do + i >= Nx.axis_size(target_x, 0) -> + Nx.u8(0) - result = Nx.dot(x_poly, [1], [0], coef_poly, [1], [0]) + Nx.take(target_x, i) > right_bound -> + Nx.u8(0) - Nx.reshape(result, original_shape) + true -> + Nx.u8(1) + end end end diff --git a/lib/scholar/linear/isotonic_regression.ex b/lib/scholar/linear/isotonic_regression.ex index e4f641be..22996eb0 100644 --- a/lib/scholar/linear/isotonic_regression.ex +++ b/lib/scholar/linear/isotonic_regression.ex @@ -204,7 +204,7 @@ defmodule Scholar.Linear.IsotonicRegression do iex> Scholar.Linear.IsotonicRegression.predict(model, to_predict) #Nx.Tensor< f32[10] - [1.0, 1.6666667461395264, 2.3333334922790527, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] + [1.0, 1.6666667461395264, 2.3333332538604736, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] > """ defn predict(model, x) do @@ -259,7 +259,7 @@ defmodule Scholar.Linear.IsotonicRegression do ] ), x: Nx.tensor( - [1.0, 4.0, 7.0, 9.0, 10.0] + [1.0, 4.0, 7.0, 9.0, 10.0, 11.0] ) } } diff --git a/test/scholar/linear/isotonic_regression_test.exs b/test/scholar/linear/isotonic_regression_test.exs index 4f21a7d3..a954a0bd 100644 --- a/test/scholar/linear/isotonic_regression_test.exs +++ b/test/scholar/linear/isotonic_regression_test.exs @@ -209,7 +209,7 @@ defmodule Scholar.Linear.IsotonicRegressionTest do [43.0, 44.0, 46.0, 47.0, 55.0, 56.0] ++ [60.0, 61.0, 66.0, 67.0, 71.0, 72.0] ++ [73.0, 76.0, 77.0, 80.0, 81.0, 87.0] ++ - [88.0, 92.0, 93.0, 97.0, 98.0] + [88.0, 92.0, 93.0, 97.0, 98.0, 99.0] ) ) end