diff --git a/CHANGELOG.md b/CHANGELOG.md index 83b28339..0a24a21a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,37 +1,44 @@ # Changelog -## v0.2.1-dev +## v0.2.2-dev -## v0.2.0 (2023-08-16) +## v0.2.1 (2023-08-30) + +### Enhancements + + * Remove `VegaLite.Data` in favour of future use of `Tucan` + * Do not use EXLA at compile time in `Metrics` + +## v0.2.0 (2023-08-29) This version requires Elixir v1.14+. ### Enhancements * Update notebooks - * Add support for :f16 and :bf16 types in SVD - * Add Affinity Propagation - * Add t-SNE - * Add Polynomial Regression - * Replace seeds with Random.key + * Add support for `:f16` and `:bf16` types in `SVD` + * Add `Affinity Propagation` + * Add `t-SNE` + * Add `Polynomial Regression` + * Replace seeds with `Random.key` * Add 'unrolling loops' option - * Add support for custom optimizers in Logistic Regression - * Add Trapezoidal Integration - * Add AUC ROC, AUC, and ROC Curve - * Add Simpson rule Integration + * Add support for custom optimizers in `Logistic Regression` + * Add `Trapezoidal Integration` + * Add `AUC-ROC`, `AUC`, and `ROC Curve` + * Add `Simpson rule integration` * Unify tests - * Add Radius Nearest Neighbors - * Add DBSCAN - * Add classification metrics: Average Precision Score, Balanced Accuracy Score, - Cohen Kappa Score, Brier Score Loss, Zero-One Loss, Top-k Accuracy Score - * Add regression metrics: R2 Score, MSLE, MAPE, Maximum Residual Error - * Add support for axes in Confusion Matrix - * Add support for broadcasting in Metrics.Distances + * Add `Radius Nearest Neighbors` + * Add `DBSCAN` + * Add classification metrics: `Average Precision Score`, `Balanced Accuracy Score`, + `Cohen Kappa Score`, `Brier Score Loss`, `Zero-One Loss`, `Top-k Accuracy Score` + * Add regression metrics: `R2 Score`, `MSLE`, `MAPE`, `Maximum Residual Error` + * Add support for axes in `Confusion Matrix` + * Add support for broadcasting in `Metrics.Distances` * Update CI - * Add Gaussian Mixtures - * Add Model selection functionalities: K-fold, K-fold Cross Validation, Grid Search - * Change structure of metrics in Scholar - * Add a guide with Cross-Validation and Grid Search + * Add `Gaussian Mixtures` + * Add Model selection functionalities: `K-fold`, `K-fold Cross Validation`, `Grid Search` + * Change structure of metrics in `Scholar` + * Add a guide with `Cross-Validation` and `Grid Search` ## v0.1.0 (2023-03-29) diff --git a/README.md b/README.md index c3fa91d1..0533a518 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ Add to your `mix.exs`: ```elixir def deps do [ - {:scholar, "~> 0.1"} + {:scholar, "~> 0.2.1"} ] end ``` @@ -30,7 +30,7 @@ such as EXLA: ```elixir def deps do [ - {:scholar, "~> 0.1"}, + {:scholar, "~> 0.2.1"}, {:exla, ">= 0.0.0"} ] end @@ -51,7 +51,7 @@ To use Scholar inside code notebooks, run: ```elixir Mix.install([ - {:scholar, "~> 0.1"}, + {:scholar, "~> 0.2.1"}, {:exla, ">= 0.0.0"} ]) diff --git a/lib/scholar/cluster/affinity_propagation.ex b/lib/scholar/cluster/affinity_propagation.ex index 58fb967b..caf7a50b 100644 --- a/lib/scholar/cluster/affinity_propagation.ex +++ b/lib/scholar/cluster/affinity_propagation.ex @@ -116,6 +116,7 @@ defmodule Scholar.Cluster.AffinityPropagation do iterations = opts[:iterations] damping_factor = opts[:damping_factor] self_preference = opts[:self_preference] + data = to_float(data) {initial_a, initial_r, s, affinity_matrix} = initialize_matrices(data, self_preference: self_preference) @@ -307,14 +308,12 @@ defmodule Scholar.Cluster.AffinityPropagation do end defnp 
initialize_similarities(data, opts \\ []) do
-    {n, dims} = Nx.shape(data)
+    n = Nx.axis_size(data, 0)
     self_preference = opts[:self_preference]

-    t1 = Nx.reshape(data, {1, n, dims}) |> Nx.broadcast({n, n, dims})
-    t2 = Nx.reshape(data, {n, 1, dims}) |> Nx.broadcast({n, n, dims})
-
-    dist =
-      (-1 * Scholar.Metrics.Distance.squared_euclidean(t1, t2, axes: [-1]))
-      |> Nx.as_type(to_float_type(data))
+    norm1 = Nx.sum(data ** 2, axes: [1], keep_axes: true)
+    norm2 = Nx.transpose(norm1)
+    dist = -1 * (norm1 + norm2 - 2 * Nx.dot(data, [1], data, [1]))

     fill_in =
       cond do
diff --git a/lib/scholar/integrate/integrate.ex b/lib/scholar/integrate/integrate.ex
index 4213e8fd..0a348062 100644
--- a/lib/scholar/integrate/integrate.ex
+++ b/lib/scholar/integrate/integrate.ex
@@ -17,7 +17,7 @@ defmodule Scholar.Integrate do
     keep_axis: [
       type: :boolean,
       default: false,
-      doc: "If set to true, the axis which is reduced are kept."
+      doc: "If set to true, the axis which is reduced is kept."
     ]
   ]

diff --git a/lib/scholar/linear/isotonic_regression.ex b/lib/scholar/linear/isotonic_regression.ex
new file mode 100644
index 00000000..0360b995
--- /dev/null
+++ b/lib/scholar/linear/isotonic_regression.ex
@@ -0,0 +1,522 @@
+defmodule Scholar.Linear.IsotonicRegression do
+  @moduledoc """
+  Isotonic regression is a method of fitting a free-form line to a set of
+  observations by solving a convex optimization problem. It is a form of
+  regression analysis that can be used as an alternative to polynomial
+  regression to fit nonlinear data.
+  """
+  require Nx
+  import Nx.Defn, except: [transform: 2]
+  import Scholar.Shared
+
+  @derive {
+    Nx.Container,
+    containers: [
+      :increasing,
+      :x_min,
+      :x_max,
+      :x_thresholds,
+      :y_thresholds,
+      :cutoff_index,
+      :preprocess
+    ]
+  }
+  defstruct [
+    :x_min,
+    :x_max,
+    :x_thresholds,
+    :y_thresholds,
+    :increasing,
+    :cutoff_index,
+    :preprocess
+  ]
+
+  @type t() :: %__MODULE__{
+          x_min: Nx.Tensor.t(),
+          x_max: Nx.Tensor.t(),
+          x_thresholds: Nx.Tensor.t(),
+          y_thresholds: Nx.Tensor.t(),
+          increasing: Nx.Tensor.t(),
+          cutoff_index: Nx.Tensor.t(),
+          preprocess: tuple() | Scholar.Interpolation.Linear.t()
+        }
+
+  opts = [
+    y_min: [
+      type: :float,
+      doc: """
+      Lower bound on the lowest predicted value. If not provided, the lower bound
+      is set to `Nx.Constants.neg_infinity()`.
+      """
+    ],
+    y_max: [
+      type: :float,
+      doc: """
+      Upper bound on the highest predicted value. If not provided, the upper bound
+      is set to `Nx.Constants.infinity()`.
+      """
+    ],
+    increasing: [
+      type: {:in, [:auto, true, false]},
+      default: :auto,
+      doc: """
+      Whether the isotonic regression should be fit with the constraint that the
+      function is monotonically increasing. If `false`, the constraint is that
+      the function is monotonically decreasing. If `:auto`, the constraint is
+      determined automatically based on the data.
+      """
+    ],
+    out_of_bounds: [
+      type: {:in, [:clip, :nan]},
+      default: :nan,
+      doc: """
+      How to handle out-of-bounds points. If `:clip`, out-of-bounds points are
+      mapped to the nearest valid value. If `:nan`, out-of-bounds points are
+      replaced with `Nx.Constants.nan()`.
+      """
+    ],
+    sample_weights: [
+      type: {:custom, Scholar.Options, :weights, []},
+      doc: """
+      The weights for each observation. If not provided,
+      all observations are assigned equal weight.
+      """
+    ]
+  ]
+
+  @opts_schema NimbleOptions.new!(opts)
+
+  @doc """
+  Fits an isotonic regression model for sample inputs `x` and
+  sample targets `y`.
+ + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Return Values + + The function returns a struct with the following parameters: + + * `:x_min` - Minimum value of input tensor `x`. + + * `:x_max` - Maximum value of input tensor `x`. + + * `:x_thresholds` - Thresholds used for predictions. + + * `:y_thresholds` - Predicted values associated with each threshold. + + * `:increasing` - Whether the isotonic regression is increasing. + + * `:cutoff_index` - The index of the last valid threshold. Rest elements are placeholders + for the sake of preserving shape of tensor. + + * `:preprocess` - Interpolation function to be applied on input tensor `x`. Before `preprocess/1` + is applied it is set to {} + + ## Examples + + iex> x = Nx.tensor([1, 4, 7, 9, 10, 11]) + iex> y = Nx.tensor([1, 3, 6, 8, 9, 10]) + iex> Scholar.Linear.IsotonicRegression.fit(x, y) + %Scholar.Linear.IsotonicRegression{ + x_min: Nx.tensor( + 1.0 + ), + x_max: Nx.tensor( + 11.0 + ), + x_thresholds: Nx.tensor( + [1.0, 4.0, 7.0, 9.0, 10.0, 11.0] + ), + y_thresholds: Nx.tensor( + [1.0, 3.0, 6.0, 8.0, 9.0, 10.0] + ), + increasing: Nx.u8(1), + cutoff_index: Nx.tensor( + 5 + ), + preprocess: {} + } + """ + deftransform fit(x, y, opts \\ []) do + opts = NimbleOptions.validate!(opts, @opts_schema) + + opts = + [ + sample_weights_flag: opts[:sample_weights] != nil + ] ++ + opts + + {sample_weights, opts} = Keyword.pop(opts, :sample_weights, 1.0) + x_type = to_float_type(x) + x = to_float(x) + y = to_float(y) + + sample_weights = + if Nx.is_tensor(sample_weights), + do: Nx.as_type(sample_weights, x_type), + else: Nx.tensor(sample_weights, type: x_type) + + sample_weights = Nx.broadcast(sample_weights, {Nx.axis_size(y, 0)}) + + {increasing, opts} = Keyword.pop(opts, :increasing) + + increasing = + case increasing do + :auto -> + check_increasing(x, y) + + true -> + Nx.u8(1) + + false -> + Nx.u8(0) + end + + fit_n(x, y, sample_weights, increasing, opts) + end + + defnp fit_n(x, y, sample_weights, increasing, opts) do + {x_min, x_max, x_unique, y, index_cut} = build_y(x, y, sample_weights, increasing, opts) + + %__MODULE__{ + x_min: x_min, + x_max: x_max, + x_thresholds: x_unique, + y_thresholds: y, + increasing: increasing, + cutoff_index: index_cut, + preprocess: {} + } + end + + @doc """ + Makes predictions with the given `model` on input `x` and interpolating `function`. + + ## Examples + + iex> x = Nx.tensor([1, 4, 7, 9, 10, 11]) + iex> y = Nx.tensor([1, 3, 6, 8, 9, 10]) + iex> model = Scholar.Linear.IsotonicRegression.fit(x, y) + iex> model = Scholar.Linear.IsotonicRegression.preprocess(model) + iex> to_predict = Nx.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + iex> Scholar.Linear.IsotonicRegression.predict(model, to_predict) + #Nx.Tensor< + f32[10] + [1.0, 1.6666667461395264, 2.3333334922790527, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0] + > + """ + defn predict(model, x) do + check_input_shape(x) + check_preprocess(model) + + x = Nx.flatten(x) + x = Nx.clip(x, model.x_min, model.x_max) + + Scholar.Interpolation.Linear.predict( + model.preprocess, + x + ) + end + + @doc """ + Preprocesses the `model` for prediction. + + Returns an updated `model`. 
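+
+  The optional second argument, `trim_duplicates` (defaults to `true`), controls
+  whether thresholds that do not change the interpolation (interior points of
+  constant runs in `y_thresholds`) are dropped before the interpolation is fitted.
+  For example, to keep all thresholds (illustrative call only):
+
+      Scholar.Linear.IsotonicRegression.preprocess(model, false)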
+ + ## Examples + + iex> x = Nx.tensor([1, 4, 7, 9, 10, 11]) + iex> y = Nx.tensor([1, 3, 6, 8, 9, 10]) + iex> model = Scholar.Linear.IsotonicRegression.fit(x, y) + iex> Scholar.Linear.IsotonicRegression.preprocess(model) + %Scholar.Linear.IsotonicRegression{ + x_min: Nx.tensor( + 1.0 + ), + x_max: Nx.tensor( + 11.0 + ), + x_thresholds: Nx.tensor( + [1.0, 4.0, 7.0, 9.0, 10.0, 11.0] + ), + y_thresholds: Nx.tensor( + [1.0, 3.0, 6.0, 8.0, 9.0, 10.0] + ), + increasing: Nx.u8(1), + cutoff_index: Nx.tensor( + 5 + ), + preprocess: %Scholar.Interpolation.Linear{ + coefficients: Nx.tensor( + [ + [0.6666666865348816, 0.3333333134651184], + [1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0], + [1.0, -1.0] + ] + ), + x: Nx.tensor( + [1.0, 4.0, 7.0, 9.0, 10.0] + ) + } + } + """ + def preprocess(model, trim_duplicates \\ true) do + cutoff = Nx.to_number(model.cutoff_index) + x = model.x_thresholds[0..cutoff] + y = model.y_thresholds[0..cutoff] + + {x, y} = + if trim_duplicates do + keep_mask = + Nx.logical_or( + Nx.not_equal(y[1..-2//1], y[0..-3//1]), + Nx.not_equal(y[1..-2//1], y[2..-1//1]) + ) + + keep_mask = Nx.concatenate([Nx.tensor([1]), keep_mask, Nx.tensor([1])]) + + indices = + Nx.iota({Nx.axis_size(y, 0)}) + |> Nx.add(1) + |> Nx.multiply(keep_mask) + |> Nx.to_flat_list() + + indices = Enum.filter(indices, fn x -> x != 0 end) |> Nx.tensor() |> Nx.subtract(1) + x = Nx.take(x, indices) + y = Nx.take(y, indices) + {x, y} + else + {x, y} + end + + model = %__MODULE__{model | x_thresholds: x} + model = %__MODULE__{model | y_thresholds: y} + + %__MODULE__{ + model + | preprocess: + Scholar.Interpolation.Linear.fit( + model.x_thresholds, + model.y_thresholds + ) + } + end + + deftransform check_preprocess(model) do + if model.preprocess == {} do + raise ArgumentError, + "model has not been preprocessed. 
" <> + "Please call preprocess/1 on the model before calling predict/2" + end + end + + defnp lexsort(x, y) do + iota = Nx.iota(Nx.shape(x)) + indices = Nx.argsort(x) + y = Nx.take(y, indices) + iota = Nx.take(iota, indices) + indices = Nx.argsort(y) + Nx.take(iota, indices) + end + + defnp build_y(x, y, sample_weights, increasing, opts) do + check_input_shape(x) + x = Nx.flatten(x) + lex_indices = lexsort(y, x) + x = Nx.take(x, lex_indices) + y = Nx.take(y, lex_indices) + sample_weights = Nx.take(sample_weights, lex_indices) + + {x_unique, y_unique, sample_weights_unique, index_cut} = make_unique(x, y, sample_weights) + + y = isotonic_regression(y_unique, sample_weights_unique, index_cut, increasing, opts) + + x_min = + Nx.reduce_min( + Nx.select(Nx.iota(Nx.shape(x_unique)) <= index_cut, x_unique, Nx.Constants.infinity()) + ) + + x_max = + Nx.reduce_max( + Nx.select(Nx.iota(Nx.shape(x_unique)) <= index_cut, x_unique, Nx.Constants.neg_infinity()) + ) + + {x_min, x_max, x_unique, y, index_cut} + end + + defnp isotonic_regression(y, sample_weights, max_size, increasing, opts) do + y_min = + case opts[:y_min] do + nil -> Nx.Constants.neg_infinity() + _ -> opts[:y_min] + end + + y_max = + case opts[:y_max] do + nil -> Nx.Constants.infinity() + _ -> opts[:y_max] + end + + y = contiguous_isotonic_regression(y, sample_weights, max_size, increasing) + + Nx.clip(y, y_min, y_max) + end + + deftransformp check_input_shape(x) do + if not (Nx.rank(x) == 1 or (Nx.rank(x) == 2 and Nx.axis_size(x, 1) == 1)) do + raise ArgumentError, + "Expected input to be a 1d tensor or 2d tensor with axis 1 of size 1, " <> + "got: #{inspect(Nx.shape(x))}" + end + end + + defnp make_unique(x, y, sample_weights) do + x_output = Nx.broadcast(Nx.tensor(0, type: Nx.type(x)), x) + + sample_weights_output = + Nx.broadcast(Nx.tensor(1, type: Nx.type(sample_weights)), sample_weights) + + type_wy = Nx.Type.merge(Nx.type(y), Nx.type(sample_weights_output)) + y_output = Nx.broadcast(Nx.tensor(0, type: type_wy), y) + + current_x = Nx.as_type(x[0], Nx.type(x)) + current_y = Nx.tensor(0, type: type_wy) + current_weight = Nx.tensor(0, type: Nx.type(sample_weights)) + + index = 0 + + {{x_output, y_output, sample_weights_output, index, current_x, current_y, current_weight}, _} = + while {{x_output, y_output, sample_weights_output, index, current_x, current_y, + current_weight}, {j = 0, eps = 1.0e-10, y, x, sample_weights}}, + j < Nx.axis_size(x, 0) do + x_j = x[j] + + {x_output, y_output, sample_weights_output, index, current_x, current_weight, current_y} = + if x_j - current_x >= eps do + x_output = Nx.indexed_put(x_output, Nx.new_axis(index, 0), current_x) + y_output = Nx.indexed_put(y_output, Nx.new_axis(index, 0), current_y / current_weight) + + sample_weights_output = + Nx.indexed_put(sample_weights_output, Nx.new_axis(index, 0), current_weight) + + index = index + 1 + current_x = x_j + current_weight = sample_weights[j] + current_y = y[j] * sample_weights[j] + + {x_output, y_output, sample_weights_output, index, current_x, current_weight, + current_y} + else + current_weight = current_weight + sample_weights[j] + current_y = current_y + y[j] * sample_weights[j] + + {x_output, y_output, sample_weights_output, index, current_x, current_weight, + current_y} + end + + {{x_output, y_output, sample_weights_output, index, current_x, current_y, current_weight}, + {j + 1, eps, y, x, sample_weights}} + end + + x_output = Nx.indexed_put(x_output, Nx.new_axis(index, 0), current_x) + y_output = Nx.indexed_put(y_output, Nx.new_axis(index, 0), 
current_y / current_weight) + + sample_weights_output = + Nx.indexed_put(sample_weights_output, Nx.new_axis(index, 0), current_weight) + + {x_output, y_output, sample_weights_output, index} + end + + defnp contiguous_isotonic_regression(y, sample_weights, max_size, increasing) do + y_size = if increasing, do: max_size, else: Nx.axis_size(y, 0) - 1 + y = if increasing, do: y, else: Nx.reverse(y) + sample_weights = if increasing, do: sample_weights, else: Nx.reverse(sample_weights) + + target = Nx.iota({Nx.axis_size(y, 0)}, type: :s64) + type_wy = Nx.Type.merge(Nx.type(y), Nx.type(sample_weights)) + i = if increasing, do: 0, else: Nx.axis_size(y, 0) - 1 - max_size + + {{y, target}, _} = + while {{y, target}, + {i, sample_weights, sum_w = Nx.tensor(0, type: Nx.type(sample_weights)), + sum_wy = Nx.tensor(0, type: type_wy), prev_y = Nx.tensor(0, type: type_wy), _k = 0, + terminating_flag = 0, y_size}}, + i < y_size + 1 and not terminating_flag do + k = target[i] + 1 + + cond do + k == y_size + 1 -> + {{y, target}, {i, sample_weights, sum_w, sum_wy, prev_y, k, 1, y_size}} + + y[i] < y[k] -> + i = k + + {{y, target}, {i, sample_weights, sum_w, sum_wy, prev_y, k, terminating_flag, y_size}} + + true -> + sum_wy = sample_weights[i] * y[i] + sum_w = sample_weights[i] + + {y, sample_weights, i, target, sum_w, sum_wy, prev_y, k, _inner_terminating_flag, + y_size} = + while {y, sample_weights, i, target, sum_w, sum_wy, _prev_y = prev_y, k, + inner_terminating_flag = 0, y_size}, + not inner_terminating_flag do + prev_y = y[k] + sum_wy = sum_wy + sample_weights[k] * y[k] + sum_w = sum_w + sample_weights[k] + k = target[k] + 1 + + {y, sample_weights, target, i, inner_terminating_flag} = + if k == y_size + 1 or prev_y < y[k] do + y = Nx.indexed_put(y, Nx.new_axis(i, 0), sum_wy / sum_w) + sample_weights = Nx.indexed_put(sample_weights, Nx.new_axis(i, 0), sum_w) + target = Nx.indexed_put(target, Nx.new_axis(i, 0), k - 1) + target = Nx.indexed_put(target, Nx.new_axis(k - 1, 0), i) + + i = + if i > 0 do + target[i - 1] + else + i + end + + {y, sample_weights, target, i, 1} + else + {y, sample_weights, target, i, 0} + end + + {y, sample_weights, i, target, sum_w, sum_wy, prev_y, k, inner_terminating_flag, + y_size} + end + + {{y, target}, {i, sample_weights, sum_w, sum_wy, prev_y, k, terminating_flag, y_size}} + end + end + + i = if increasing, do: 0, else: Nx.axis_size(y, 0) - 1 - max_size + + {y, _} = + while {y, {target, i, _k = 0, max_size}}, i < max_size + 1 do + k = target[i] + 1 + indices = Nx.iota({Nx.axis_size(y, 0)}) + in_range? = Nx.logical_and(i + 1 <= indices, indices < k) + y = Nx.select(in_range?, y[i], y) + i = k + {y, {target, i, k, max_size}} + end + + if increasing, do: y, else: Nx.reverse(y) + end + + defnp check_increasing(x, y) do + x = Nx.new_axis(x, -1) + y = Nx.new_axis(y, -1) + model = Scholar.Linear.LinearRegression.fit(x, y) + model.coefficients[0][0] >= 0 + end +end diff --git a/lib/scholar/linear/svm.ex b/lib/scholar/linear/svm.ex new file mode 100644 index 00000000..ad9ebd85 --- /dev/null +++ b/lib/scholar/linear/svm.ex @@ -0,0 +1,265 @@ +defmodule Scholar.Linear.SVM do + @moduledoc """ + SVM classifier + + It uses the OvR strategy to handle both binary and multinomial classification. + This implementation uses stochastic gradient descent from default or any other optimizer + available in `Polaris`. This makes it similar to a sklearn SGDClassifier [1]. 
+  This means that, on average, it will be slower than algorithms that use QP and the kernel trick (LIBSVM [2]) or
+  the Coordinate Descent Algorithm (LIBLINEAR [3]). It also cannot use different kernels like in LIBSVM,
+  but you can use any type of optimizer available in `Polaris`.
+
+  [1] - https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.SGDClassifier.html
+  [2] - https://www.csie.ntu.edu.tw/~cjlin/libsvm/
+  [3] - https://www.csie.ntu.edu.tw/~cjlin/liblinear/
+  """
+  import Nx.Defn
+  import Scholar.Shared
+
+  @derive {Nx.Container, containers: [:coefficients, :bias]}
+  defstruct [:coefficients, :bias]
+
+  opts = [
+    num_classes: [
+      required: true,
+      type: :pos_integer,
+      doc: "number of classes contained in the input tensors."
+    ],
+    iterations: [
+      type: :pos_integer,
+      default: 1000,
+      doc: """
+      number of iterations of gradient descent performed inside SVM.
+      """
+    ],
+    learning_loop_unroll: [
+      type: :boolean,
+      default: false,
+      doc: ~S"""
+      If `true`, the learning loop is unrolled.
+      """
+    ],
+    optimizer: [
+      type: {:custom, Scholar.Options, :optimizer, []},
+      default: :sgd,
+      doc: """
+      The optimizer name or `{init, update}` pair of functions (see `Polaris.Optimizers` for more details).
+      """
+    ],
+    eps: [
+      type: :float,
+      default: 1.0e-8,
+      doc:
+        "The convergence tolerance. If `abs(loss) < size(x) * :eps`, the algorithm is considered to have converged."
+    ],
+    loss_fn: [
+      type: {:custom, Scholar.Linear.SVM, :loss_function, []},
+      default: nil,
+      doc: """
+      The loss function that is used in the algorithm. Functions should take two arguments: `y_predicted` and `y_true`.
+      If not provided, it is set to the hinge loss without regularization.
+      """
+    ]
+  ]
+
+  def loss_function(function) do
+    case function do
+      function when is_function(function, 2) ->
+        {:ok, function}
+
+      nil ->
+        loss_fn = &Scholar.Linear.SVM.hinge_loss(&1, &2, c: 1.0, margin: 10)
+
+        {:ok, loss_fn}
+
+      _ ->
+        {:error,
+         "expected loss function to be a function with arity 2, got: #{inspect(function)}"}
+    end
+  end
+
+  @opts_schema NimbleOptions.new!(opts)
+
+  @doc """
+  Fits an SVM model for sample inputs `x` and sample
+  targets `y`.
+
+  ## Options
+
+  #{NimbleOptions.docs(@opts_schema)}
+
+  ## Return Values
+
+  The function returns a struct with the following parameters:
+
+    * `:coefficients` - Coefficients of the features in the decision function.
+
+    * `:bias` - Bias added to the decision function.
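+
+  The `:optimizer` option also accepts an `{init, update}` pair, such as one
+  returned by the functions in `Polaris.Optimizers`. A minimal sketch (the
+  optimizer and learning rate below are arbitrary illustrative choices):
+
+      Scholar.Linear.SVM.fit(x, y,
+        num_classes: 2,
+        optimizer: Polaris.Optimizers.adam(learning_rate: 0.01)
+      )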
+ + ## Examples + iex> x = Nx.tensor([[1.0, 2.0, 2.1], [3.0, 2.0, 1.4], [4.0, 7.0, 5.3], [3.0, 4.0, 6.3]]) + iex> y = Nx.tensor([1, 0, 1, 1]) + iex> Scholar.Linear.SVM.fit(x, y, num_classes: 2) + %Scholar.Linear.SVM{ + coefficients: Nx.tensor( + [ + [1.6899993419647217, 1.4599995613098145, 1.322001338005066], + [1.4799995422363281, 1.9599990844726562, 2.0080013275146484] + ] + ), + bias: Nx.tensor( + [0.23000003397464752, 0.4799998104572296] + ) + } + """ + deftransform fit(x, y, opts \\ []) do + if Nx.rank(x) != 2 do + raise ArgumentError, + "expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}" + end + + if Nx.rank(y) != 1 do + raise ArgumentError, + "expected y to have shape {n_samples}, got tensor with shape: #{inspect(Nx.shape(y))}" + end + + opts = NimbleOptions.validate!(opts, @opts_schema) + + {optimizer, opts} = Keyword.pop!(opts, :optimizer) + + {optimizer_init_fn, optimizer_update_fn} = + case optimizer do + atom when is_atom(atom) -> apply(Polaris.Optimizers, atom, []) + {f1, f2} -> {f1, f2} + end + + n = Nx.axis_size(x, -1) + num_classes = opts[:num_classes] + + coef = + Nx.broadcast( + Nx.tensor(1.0, type: to_float_type(x)), + {num_classes, n} + ) + + bias = Nx.broadcast(Nx.tensor(0, type: to_float_type(x)), {num_classes}) + + coef_optimizer_state = optimizer_init_fn.(coef) |> as_type(to_float_type(x)) + bias_optimizer_state = optimizer_init_fn.(bias) |> as_type(to_float_type(x)) + + opts = Keyword.put(opts, :optimizer_update_fn, optimizer_update_fn) + + fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts) + end + + deftransformp as_type(container, target_type) do + Nx.Defn.Composite.traverse(container, fn t -> + type = Nx.type(t) + + if Nx.Type.float?(type) and not Nx.Type.complex?(type) do + Nx.as_type(t, target_type) + else + t + end + end) + end + + # SVM training loop + defnp fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts) do + iterations = opts[:iterations] + num_classes = opts[:num_classes] + optimizer_update_fn = opts[:optimizer_update_fn] + eps = opts[:eps] + + {{final_coef, final_bias}, _} = + while {{coef, bias}, + {x, iterations, y, coef_optimizer_state, bias_optimizer_state, + has_converged = Nx.broadcast(Nx.u8(0), {num_classes}), eps, iter = 0}}, + iter < iterations and not Nx.all(has_converged) do + # ++++ inner while ++++++ + {{coef, bias, has_converged, coef_optimizer_state, bias_optimizer_state}, _} = + while {{coef, bias, has_converged, coef_optimizer_state, bias_optimizer_state}, + {x, y, iterations, iter, eps, j = 0}}, + j < num_classes do + y_j = y == j + coef_j = Nx.take(coef, j) + bias_j = Nx.take(bias, j) + + {loss, {coef_grad, bias_grad}} = loss_and_grad(coef_j, bias_j, x, y_j, opts[:loss_fn]) + grad = Nx.broadcast(0.0, {num_classes, Nx.axis_size(x, 1)}) + coef_grad = Nx.put_slice(grad, [j, 0], Nx.new_axis(coef_grad, 0)) + + {coef_updates, coef_optimizer_state} = + optimizer_update_fn.(coef_grad, coef_optimizer_state, coef) + + coef = Polaris.Updates.apply_updates(coef, coef_updates) + + grad = Nx.broadcast(0.0, {num_classes}) + bias_grad = Nx.put_slice(grad, [j], Nx.new_axis(bias_grad, 0)) + + {bias_updates, bias_optimizer_state} = + optimizer_update_fn.(bias_grad, bias_optimizer_state, bias) + + bias = Polaris.Updates.apply_updates(bias, bias_updates) + + has_converged_j = Nx.sum(Nx.abs(loss)) < Nx.axis_size(x, 0) * eps + + has_converged = + Nx.indexed_put( + has_converged, + Nx.new_axis(j, -1), + has_converged_j + ) + + {{coef, bias, has_converged, 
coef_optimizer_state, bias_optimizer_state}, + {x, y, iterations, iter, eps, j + 1}} + end + + # ++++ end inner while ++++++ + + {{coef, bias}, + {x, iterations, y, coef_optimizer_state, bias_optimizer_state, has_converged, eps, + iter + 1}} + end + + %__MODULE__{ + coefficients: final_coef, + bias: final_bias + } + end + + defnp loss_and_grad(coeff, bias, xs, ys, loss_fn) do + value_and_grad({coeff, bias}, fn {coeff, bias} -> + y_pred = predict(coeff, bias, xs) + loss_fn.(y_pred, ys) + end) + end + + defnp predict(coeff, bias, xs) do + Nx.dot(xs, [-1], coeff, [-1]) + bias + end + + defn hinge_loss(y_pred, ys, opts \\ []) do + c = opts[:c] + margin = opts[:margin] + c * Nx.sum(Nx.max(0, margin - y_pred) * ys, axes: [-1]) + end + + @doc """ + Makes predictions with the given model on inputs `x`. + + ## Examples + iex> x = Nx.tensor([[1.0, 2.0], [3.0, 2.0], [4.0, 7.0]]) + iex> y = Nx.tensor([1, 0, 1]) + iex> model = Scholar.Linear.SVM.fit(x, y, num_classes: 2) + iex> Scholar.Linear.SVM.predict(model, Nx.tensor([[-3.0, 5.0]])) + #Nx.Tensor< + s64[1] + [1] + > + """ + defn predict(%__MODULE__{coefficients: coeff, bias: bias}, x) do + score = predict(coeff, bias, x) + Nx.argmax(score, axis: -1) + end +end diff --git a/lib/scholar/manifold/tsne.ex b/lib/scholar/manifold/tsne.ex index 57c4bc1d..d6b27c52 100644 --- a/lib/scholar/manifold/tsne.ex +++ b/lib/scholar/manifold/tsne.ex @@ -2,9 +2,9 @@ defmodule Scholar.Manifold.TSNE do @moduledoc """ TSNE (t-Distributed Stochastic Neighbor Embedding) is a nonlinear dimensionality reduction technique. - ## References + ## Reference - * [t-SNE: t-Distributed Stochastic Neighbor Embedding](http://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf) + * [Van der Maaten, L., & Hinton, G. (2008). Visualizing data using t-SNE. 
Journal of machine learning research, 9(11).](http://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf) """ import Nx.Defn import Scholar.Shared diff --git a/mix.exs b/mix.exs index cb04e0ff..4908188d 100644 --- a/mix.exs +++ b/mix.exs @@ -30,7 +30,8 @@ defmodule Scholar.MixProject do defp deps do [ {:ex_doc, "~> 0.30", only: :docs}, - {:nx, "~> 0.6"}, + # {:nx, "~> 0.6", override: true}, + {:nx, github: "elixir-nx/nx", sparse: "nx", override: true}, {:nimble_options, "~> 0.5.2 or ~> 1.0"}, {:exla, "~> 0.6", optional: true}, {:polaris, "~> 0.1"} @@ -55,7 +56,8 @@ defmodule Scholar.MixProject do extras: [ "notebooks/linear_regression.livemd", "notebooks/k_means.livemd", - "notebooks/k_nearest_neighbors.livemd" + "notebooks/k_nearest_neighbors.livemd", + "notebooks/cv_gradient_boosting_tree.livemd" ], groups_for_modules: [ Models: [ @@ -68,10 +70,12 @@ defmodule Scholar.MixProject do Scholar.Interpolation.BezierSpline, Scholar.Interpolation.CubicSpline, Scholar.Interpolation.Linear, + Scholar.Linear.IsotonicRegression, Scholar.Linear.LinearRegression, Scholar.Linear.LogisticRegression, Scholar.Linear.PolynomialRegression, Scholar.Linear.RidgeRegression, + Scholar.Linear.SVM, Scholar.Manifold.TSNE, Scholar.NaiveBayes.Complement, Scholar.NaiveBayes.Gaussian, diff --git a/mix.lock b/mix.lock index aa10af5c..a69d4862 100644 --- a/mix.lock +++ b/mix.lock @@ -9,7 +9,7 @@ "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, "nimble_options": {:hex, :nimble_options, "0.5.2", "42703307b924880f8c08d97719da7472673391905f528259915782bb346e0a1b", [:mix], [], "hexpm", "4da7f904b915fd71db549bcdc25f8d56f378ef7ae07dc1d372cbe72ba950dce0"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:hex, :nx, "0.6.0", "37c86eae824125a7e298dd1ee896953d9d671ce3630dcff74c77db17d734a85f", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "e1ad3cc70a5828a1aedb156b71e90863d9623a2dc9b35a5588f8627a07ee6cb4"}, + "nx": {:git, "https://github.com/elixir-nx/nx.git", "a0b7e2e5cc7a62a55cd2e7bbc3e44ba2ac1c996b", [sparse: "nx"]}, "polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "xla": {:hex, :xla, "0.5.0", "fb8a02c02e5a4f4531fbf18a90c325e471037f983f0115d23f510e7dd9a6aa65", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "571ac797a4244b8ba8552ed0295a54397bd896708be51e4da6cbb784f6678061"}, diff --git a/notebooks/cv_gradient_boosting_tree.livemd b/notebooks/cv_gradient_boosting_tree.livemd index 6334479a..9fe6d970 100644 --- a/notebooks/cv_gradient_boosting_tree.livemd +++ b/notebooks/cv_gradient_boosting_tree.livemd @@ -21,7 +21,7 @@ require Explorer.DataFrame, 
as: DF require Explorer.Series, as: S ``` -In this notebook we are going to work with Medical Cost Personal Datasets to predict medical charges that were applied to each person from the dataset. +In this notebook we are going to work with [Medical Cost Personal Datasets](https://www.kaggle.com/datasets/mirichoi0218/insurance) to predict medical charges that were applied to each person from the dataset. ```elixir data = @@ -61,7 +61,7 @@ Before training our model, we separate the data between train and test sets. Gradient boosting works by sequentially adding predictors to an ensemble, each one correcting its predecessor. Let's go through a simple regression example, using decision trees as the base predictors; this is called _gradient tree boosting_, or _gradient boosted regression trees_ (GBRT). -EXGBoost provides an implementation of gradient boosting trees that accepts a wide range of hyperparameter configurations. For the full list of hyperparameters refer to the EXGBoost docs. +EXGBoost provides an implementation of gradient boosting trees that accepts a wide range of hyperparameter configurations. For the full list of hyperparameters refer to the [EXGBoost](https://hexdocs.pm/exgboost/EXGBoost.html) docs. ```elixir y_pred = @@ -92,9 +92,9 @@ With very little preprocessing we get similar results to the linear regression m ## Evaluating with cross-validation -_k_-fold cross-validation works by creating splits on the training set into _k_ smaller sets, so that the model is trained using $ k - 1 $ splits (folds) as training data and is validated on the remaining part of the data. When using this technique, the performance measure is the average of the values computed in each iteration. +_k_-fold cross-validation works by creating splits on the training set into _k_ smaller sets, so that the model is trained using $k - 1$ splits (folds) as training data and is validated on the remaining part of the data. When using this technique, the performance measure is the average of the values computed in each iteration. - + @@ -153,12 +153,16 @@ S.mean(cv_score) ## Fine-tuning our model with grid search -Finding the right configuration of hyperparameters is an important part of the process when selecting a model. One could try different combinations of hyperparameter values manually, but this can get tedious and time consuming. Instead, we can use the _Grid Search_ method, an iterative process for finding an optimal configuration of hyperparameter values for a given model. +Finding the right configuration of hyperparameters is an important part of the process when selecting a model. One could try different combinations of hyperparameter values manually, but this can get tedious and time consuming. Instead, we can use the _grid search_ method, an iterative process for finding an optimal configuration of hyperparameter values for a given model. First, we need to provide a "grid" of hyperparameter values, so that the algorithm can train and evaluate our model with all possible combinations. 
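+
+As a toy illustration, each key in the grid maps to a list of candidate values, and the search evaluates every combination of them, that is, the Cartesian product of the lists (the values below are arbitrary and only show how quickly the space grows):
+
+```elixir
+# Two keys with two candidate values each expand into four configurations.
+toy_grid = [max_depth: [2, 3], num_boost_rounds: [20, 50]]
+
+for depth <- toy_grid[:max_depth], rounds <- toy_grid[:num_boost_rounds] do
+  [max_depth: depth, num_boost_rounds: rounds]
+end
+```
+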
```elixir grid = [ + booster: [:gbtree], + objective: [:reg_squarederror], + evals: [[{x_train, y_train, "training"}]], + verbose_eval: [true], tree_method: [:approx, :exact], max_depth: [2, 3, 4, 5, 6], num_boost_rounds: [20, 50, 90], @@ -173,18 +177,9 @@ gs_scoring_fn = fn x, y, hyperparams -> {x_train, x_test} = x {y_train, y_test} = y - opts = - hyperparams ++ - [ - booster: :gbtree, - objective: :reg_squarederror, - evals: [{x_train, y_train, "training"}], - verbose_eval: true - ] - y_pred = x_train - |> EXGBoost.train(y_train, opts) + |> EXGBoost.train(y_train, hyperparams) |> EXGBoost.predict(x_test) Metrics.mean_square_error(y_test, y_pred) @@ -195,16 +190,38 @@ end Let's run the grid search and see the results. Remember that the more hyperparameter values you add to the grid, the more it will take the algorithm to end. ```elixir -gs_scores = ModelSelection.grid_search(x, y, folding_fn, gs_scoring_fn, grid) +gs_scores = + ModelSelection.grid_search( + x_train, + y_train, + folding_fn, + gs_scoring_fn, + grid + ) ``` The output is a list of maps, each corresponding to an iteration of the grid search algorithm. Every iteration yields a `score` calculated by our scoring function. Let's find the set of hyperparameters that optimizes the score. ```elixir -gs_scores -|> Enum.min_by(fn %{score: score} -> - score - |> Nx.squeeze() - |> Nx.to_number() -end) +best_config = + Enum.min_by(gs_scores, fn %{score: score} -> + score + |> Nx.squeeze() + |> Nx.to_number() + end) +``` + +Finally we train and evaluate a model using the best hyperparameter configuration found by grid search. + +```elixir +%{hyperparameters: opts} = best_config + +model = EXGBoost.train(x_train, y_train, opts) +y_pred = EXGBoost.predict(model, x_test) + +rmse = + Metrics.mean_square_error(y_test, y_pred) + |> Nx.sqrt() + +"RMSE: #{Nx.to_number(rmse)}" ``` diff --git a/test/scholar/linear/isotonic_regression_test.exs b/test/scholar/linear/isotonic_regression_test.exs new file mode 100644 index 00000000..4f21a7d3 --- /dev/null +++ b/test/scholar/linear/isotonic_regression_test.exs @@ -0,0 +1,232 @@ +defmodule Scholar.Linear.IsotonicRegressionTest do + use Scholar.Case, async: true + alias Scholar.Linear.IsotonicRegression + doctest IsotonicRegression + + def y do + Nx.tensor( + [-6.0, 31.65735903, 68.93061443, 86.31471806, 97.47189562] ++ + [48.58797346, 130.29550745, 74.97207708, 95.86122887, 152.12925465] ++ + [139.89476364, 162.24533249, 166.24746787, 93.95286648, 143.40251006] ++ + [153.62943611, 130.6606672, 181.51858789, 143.22194896, 187.78661368] ++ + [183.22612189, 141.55212267, 131.7747108, 185.90269152, 182.94379124] ++ + [121.9048269, 134.7918433, 196.61022551, 187.3647915, 199.05986908] ++ + [168.69936022, 187.28679514, 206.82537807, 225.31802623, 215.76740307] ++ + [178.17594692, 159.54589563, 150.87930799, 152.17808231, 148.44397271] ++ + [174.67860334, 168.88348091, 203.06000578, 148.2094817, 197.33312449] ++ + [173.43206982, 173.50738009, 217.56005055, 167.59101491, 180.60115027] ++ + [221.59128164, 202.56218593, 176.51459568, 183.44920233, 150.36665926] ++ + [151.26758454, 188.15256339, 206.02215053, 158.8768722, 192.71722811] ++ + [172.54369321, 235.35671925, 161.15673632, 199.94415417, 216.71936349] ++ + [190.4827371, 161.23463097, 225.97538526, 202.70532523, 219.4247621] ++ + [198.13399385, 174.83330595, 210.52297206, 247.20325466, 256.87440568] ++ + [166.53666701, 181.19027109, 266.83544133, 221.47239262, 181.10133173] ++ + [211.72245773, 254.33596236, 245.94203039, 239.54083994, 178.13256282] ++ + 
[240.71736481, 220.29540593, 176.86684072, 250.43181849, 226.99048352] ++ + [253.54297533, 191.08942885, 196.62997466, 276.16473911, 235.69384458] ++ + [201.21740957, 257.73554893, 192.24837393, 264.75599251, 228.2585093] + ) + end + + describe "fit" do + test "fit - all defaults" do + n = 100 + x = Nx.iota({n}) + + y = y() + + model = IsotonicRegression.fit(x, y) + assert model.x_min == Nx.tensor(0.0) + assert model.x_max == Nx.tensor(99.0) + assert model.x_thresholds == Nx.iota({100}, type: :f32) + + assert_all_close( + model.y_thresholds, + Nx.tensor( + [-6.0, 31.657358169555664, 68.93061065673828, 77.45819854736328] ++ + [77.45819854736328, 77.45819854736328, 100.37627410888672, 100.37627410888672] ++ + [100.37627410888672, 142.77029418945312, 142.77029418945312, 142.77029418945312] ++ + [142.77029418945312, 142.77029418945312, 142.77029418945312, 142.77029418945312] ++ + [142.77029418945312, 159.4623260498047, 159.4623260498047, 159.4623260498047] ++ + [159.4623260498047, 159.4623260498047, 159.4623260498047, 159.4623260498047] ++ + [159.4623260498047, 159.4623260498047, 159.4623260498047, 180.6462860107422] ++ + [180.6462860107422, 180.6462860107422, 180.6462860107422, 180.646286010742] ++ + [180.6462860107422, 180.6462860107422, 180.6462860107422, 180.6462860107422] ++ + [180.6462860107422, 180.6462860107422, 180.6462860107422, 180.6462860107422] ++ + [180.6462860107422, 180.6462860107422, 180.6462860107422, 180.6462860107422] ++ + [181.4241943359375, 181.4241943359375, 181.4241943359375, 183.5004119873047] ++ + [183.5004119873047, 183.5004119873047, 183.5004119873047, 183.5004119873047] ++ + [183.5004119873047, 183.5004119873047, 183.5004119873047, 183.5004119873047] ++ + [183.66250610351562, 183.66250610351562, 183.66250610351562, 183.66250610351562] ++ + [183.66250610351562, 194.1490478515625, 194.1490478515625, 194.1490478515625] ++ + [194.1490478515625, 194.1490478515625, 194.1490478515625, 204.21456909179688] ++ + [204.21456909179688, 204.21456909179688, 204.21456909179688, 204.21456909179688] ++ + [210.52296447753906, 212.95114135742188, 212.95114135742188, 212.95114135742188] ++ + [212.95114135742188, 220.2829132080078, 220.2829132080078, 220.2829132080078] ++ + [220.2829132080078, 222.26158142089844, 222.26158142089844, 222.26158142089844] ++ + [222.26158142089844, 222.26158142089844, 222.26158142089844, 222.26158142089844] ++ + [223.7369384765625, 223.7369384765625, 223.7369384765625, 223.7369384765625] ++ + [223.7369384765625, 232.61196899414062, 232.61196899414062, 232.61196899414062] ++ + [232.61196899414062, 232.61196899414062, 246.5072479248047, 246.5072479248047] + ) + ) + + assert model.increasing == Nx.u8(1) + assert model.cutoff_index == Nx.tensor(99) + assert model.preprocess == {} + end + + test "fit with sample_weights" do + x = Nx.tensor([2.0, 2.0, 3.0, 4.0, 5.0]) + y = Nx.tensor([2.0, 3.0, 7.0, 8.0, 9.0]) + sample_weights = Nx.tensor([1, 3, 2, 7, 4]) + model = IsotonicRegression.fit(x, y, sample_weights: sample_weights) + assert model.x_min == Nx.tensor(2.0) + assert model.x_max == Nx.tensor(5.0) + assert model.x_thresholds == Nx.tensor([2.0, 3.0, 4.0, 5.0, 0.0]) + assert_all_close(model.y_thresholds, Nx.tensor([2.75, 7.0, 8.0, 9.0, 0])) + + assert model.increasing == Nx.u8(1) + assert model.cutoff_index == Nx.tensor(3) + assert model.preprocess == {} + end + + test "fit with sample_weights and :increasing? 
set to false" do + x = Nx.tensor([2.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0]) + y = Nx.tensor([11, 12, 9, 7, 5, 4, 2]) + sample_weights = Nx.tensor([1, 3, 2, 7, 4, 2, 1]) + + model = + Scholar.Linear.IsotonicRegression.fit(x, y, + sample_weights: sample_weights, + increasing: false + ) + + assert model.x_min == Nx.tensor(2.0) + assert model.x_max == Nx.tensor(6.0) + assert model.x_thresholds == Nx.tensor([2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0]) + + assert_all_close( + model.y_thresholds, + Nx.tensor([11.75, 9.0, 7.0, 4.666666507720947, 2.0, 0.0, 0.0]) + ) + + assert model.increasing == Nx.u8(0) + assert model.cutoff_index == Nx.tensor(4) + assert model.preprocess == {} + end + + test "fit with sample_weights and :increasing? as default (:auto)" do + x = Nx.tensor([2.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0]) + y = Nx.tensor([11, 12, 9, 7, 5, 4, 2]) + sample_weights = Nx.tensor([1, 3, 2, 7, 4, 2, 1]) + + model = Scholar.Linear.IsotonicRegression.fit(x, y, sample_weights: sample_weights) + assert model.increasing == Nx.u8(0) + end + end + + test "preprocess" do + n = 100 + x = Nx.iota({n}) + + y = y() + + model = IsotonicRegression.fit(x, y) + model = IsotonicRegression.preprocess(model) + + assert model.x_thresholds == + Nx.tensor( + [0.0, 1.0, 2.0, 3.0, 5.0, 6.0] ++ + [8.0, 9.0, 16.0, 17.0, 26.0, 27.0] ++ + [43.0, 44.0, 46.0, 47.0, 55.0, 56.0] ++ + [60.0, 61.0, 66.0, 67.0, 71.0, 72.0] ++ + [73.0, 76.0, 77.0, 80.0, 81.0, 87.0] ++ + [88.0, 92.0, 93.0, 97.0, 98.0, 99.0] + ) + + assert_all_close( + model.y_thresholds, + Nx.tensor( + [-6.0, 31.657358169555664, 68.93061065673828, 77.45819854736328] ++ + [77.45819854736328, 100.37627410888672, 100.37627410888672, 142.77029418945312] ++ + [142.77029418945312, 159.4623260498047, 159.4623260498047, 180.6462860107422] ++ + [180.6462860107422, 181.4241943359375, 181.4241943359375, 183.5004119873047] ++ + [183.5004119873047, 183.66250610351562, 183.66250610351562, 194.1490478515625] ++ + [194.1490478515625, 204.21456909179688, 204.21456909179688, 210.52296447753906] ++ + [212.95114135742188, 212.95114135742188, 220.2829132080078, 220.2829132080078] ++ + [222.26158142089844, 222.26158142089844, 223.7369384765625, 223.7369384765625] ++ + [232.61196899414062, 232.61196899414062, 246.5072479248047, 246.5072479248047] + ) + ) + + assert_all_close( + model.preprocess.coefficients, + Nx.tensor([ + [37.65735626220703, -6.0], + [37.27325439453125, -5.615896224975586], + [8.527587890625, 51.87543487548828], + [0.0, 77.45819854736328], + [22.918075561523438, -37.132179260253906], + [0.0, 100.37627410888672], + [42.394020080566406, -238.77587890625], + [0.0, 142.77029418945312], + [16.692031860351562, -124.30221557617188], + [0.0, 159.4623260498047], + [21.1839599609375, -391.32061767578125], + [0.0, 180.6462860107422], + [0.7779083251953125, 147.19622802734375], + [0.0, 181.4241943359375], + [2.0762176513671875, 85.91818237304688], + [0.0, 183.5004119873047], + [0.1620941162109375, 174.58523559570312], + [0.0, 183.66250610351562], + [10.486541748046875, -445.5299987792969], + [0.0, 194.1490478515625], + [10.065521240234375, -470.17535400390625], + [0.0, 204.21456909179688], + [6.3083953857421875, -243.68148803710938], + [2.4281768798828125, 35.69422912597656], + [0.0, 212.95114135742188], + [7.3317718505859375, -344.2635192871094], + [0.0, 220.2829132080078], + [1.978668212890625, 61.98945617675781], + [0.0, 222.26158142089844], + [1.4753570556640625, 93.905517578125], + [0.0, 223.7369384765625], + [8.875030517578125, -592.765869140625], + [0.0, 232.61196899414062], + 
[13.895278930664062, -1115.2301025390625], + [0.0, 246.5072479248047] + ]) + ) + + assert_all_close( + model.preprocess.x, + Nx.tensor( + [0.0, 1.0, 2.0, 3.0, 5.0, 6.0] ++ + [8.0, 9.0, 16.0, 17.0, 26.0, 27.0] ++ + [43.0, 44.0, 46.0, 47.0, 55.0, 56.0] ++ + [60.0, 61.0, 66.0, 67.0, 71.0, 72.0] ++ + [73.0, 76.0, 77.0, 80.0, 81.0, 87.0] ++ + [88.0, 92.0, 93.0, 97.0, 98.0] + ) + ) + end + + test "predict" do + n = 100 + x = Nx.iota({n}) + + y = y() + + model = IsotonicRegression.fit(x, y) + model = IsotonicRegression.preprocess(model) + x_to_predict = Nx.tensor([34.64, 23.64, 46.93]) + + assert_all_close( + IsotonicRegression.predict(model, x_to_predict), + Nx.tensor([180.6462860107422, 159.4623260498047, 183.35507202148438]) + ) + end +end diff --git a/test/scholar/linear/logistic_regression_test.exs b/test/scholar/linear/logistic_regression_test.exs index 8138bbdf..549d25d3 100644 --- a/test/scholar/linear/logistic_regression_test.exs +++ b/test/scholar/linear/logistic_regression_test.exs @@ -3,180 +3,14 @@ defmodule Scholar.Linear.LogisticRegressionTest do alias Scholar.Linear.LogisticRegression doctest LogisticRegression - defp iris_data do - key = Nx.Random.key(42) - - x = - Nx.tensor([ - [5.1, 3.5, 1.4, 0.2], - [4.9, 3.0, 1.4, 0.2], - [4.7, 3.2, 1.3, 0.2], - [4.6, 3.1, 1.5, 0.2], - [5.0, 3.6, 1.4, 0.2], - [5.4, 3.9, 1.7, 0.4], - [4.6, 3.4, 1.4, 0.3], - [5.0, 3.4, 1.5, 0.2], - [4.4, 2.9, 1.4, 0.2], - [4.9, 3.1, 1.5, 0.1], - [5.4, 3.7, 1.5, 0.2], - [4.8, 3.4, 1.6, 0.2], - [4.8, 3.0, 1.4, 0.1], - [4.3, 3.0, 1.1, 0.1], - [5.8, 4.0, 1.2, 0.2], - [5.7, 4.4, 1.5, 0.4], - [5.4, 3.9, 1.3, 0.4], - [5.1, 3.5, 1.4, 0.3], - [5.7, 3.8, 1.7, 0.3], - [5.1, 3.8, 1.5, 0.3], - [5.4, 3.4, 1.7, 0.2], - [5.1, 3.7, 1.5, 0.4], - [4.6, 3.6, 1.0, 0.2], - [5.1, 3.3, 1.7, 0.5], - [4.8, 3.4, 1.9, 0.2], - [5.0, 3.0, 1.6, 0.2], - [5.0, 3.4, 1.6, 0.4], - [5.2, 3.5, 1.5, 0.2], - [5.2, 3.4, 1.4, 0.2], - [4.7, 3.2, 1.6, 0.2], - [4.8, 3.1, 1.6, 0.2], - [5.4, 3.4, 1.5, 0.4], - [5.2, 4.1, 1.5, 0.1], - [5.5, 4.2, 1.4, 0.2], - [4.9, 3.1, 1.5, 0.1], - [5.0, 3.2, 1.2, 0.2], - [5.5, 3.5, 1.3, 0.2], - [4.9, 3.1, 1.5, 0.1], - [4.4, 3.0, 1.3, 0.2], - [5.1, 3.4, 1.5, 0.2], - [5.0, 3.5, 1.3, 0.3], - [4.5, 2.3, 1.3, 0.3], - [4.4, 3.2, 1.3, 0.2], - [5.0, 3.5, 1.6, 0.6], - [5.1, 3.8, 1.9, 0.4], - [4.8, 3.0, 1.4, 0.3], - [5.1, 3.8, 1.6, 0.2], - [4.6, 3.2, 1.4, 0.2], - [5.3, 3.7, 1.5, 0.2], - [5.0, 3.3, 1.4, 0.2], - [7.0, 3.2, 4.7, 1.4], - [6.4, 3.2, 4.5, 1.5], - [6.9, 3.1, 4.9, 1.5], - [5.5, 2.3, 4.0, 1.3], - [6.5, 2.8, 4.6, 1.5], - [5.7, 2.8, 4.5, 1.3], - [6.3, 3.3, 4.7, 1.6], - [4.9, 2.4, 3.3, 1.0], - [6.6, 2.9, 4.6, 1.3], - [5.2, 2.7, 3.9, 1.4], - [5.0, 2.0, 3.5, 1.0], - [5.9, 3.0, 4.2, 1.5], - [6.0, 2.2, 4.0, 1.0], - [6.1, 2.9, 4.7, 1.4], - [5.6, 2.9, 3.6, 1.3], - [6.7, 3.1, 4.4, 1.4], - [5.6, 3.0, 4.5, 1.5], - [5.8, 2.7, 4.1, 1.0], - [6.2, 2.2, 4.5, 1.5], - [5.6, 2.5, 3.9, 1.1], - [5.9, 3.2, 4.8, 1.8], - [6.1, 2.8, 4.0, 1.3], - [6.3, 2.5, 4.9, 1.5], - [6.1, 2.8, 4.7, 1.2], - [6.4, 2.9, 4.3, 1.3], - [6.6, 3.0, 4.4, 1.4], - [6.8, 2.8, 4.8, 1.4], - [6.7, 3.0, 5.0, 1.7], - [6.0, 2.9, 4.5, 1.5], - [5.7, 2.6, 3.5, 1.0], - [5.5, 2.4, 3.8, 1.1], - [5.5, 2.4, 3.7, 1.0], - [5.8, 2.7, 3.9, 1.2], - [6.0, 2.7, 5.1, 1.6], - [5.4, 3.0, 4.5, 1.5], - [6.0, 3.4, 4.5, 1.6], - [6.7, 3.1, 4.7, 1.5], - [6.3, 2.3, 4.4, 1.3], - [5.6, 3.0, 4.1, 1.3], - [5.5, 2.5, 4.0, 1.3], - [5.5, 2.6, 4.4, 1.2], - [6.1, 3.0, 4.6, 1.4], - [5.8, 2.6, 4.0, 1.2], - [5.0, 2.3, 3.3, 1.0], - [5.6, 2.7, 4.2, 1.3], - [5.7, 3.0, 4.2, 1.2], - [5.7, 2.9, 4.2, 1.3], - [6.2, 2.9, 4.3, 
1.3], - [5.1, 2.5, 3.0, 1.1], - [5.7, 2.8, 4.1, 1.3], - [6.3, 3.3, 6.0, 2.5], - [5.8, 2.7, 5.1, 1.9], - [7.1, 3.0, 5.9, 2.1], - [6.3, 2.9, 5.6, 1.8], - [6.5, 3.0, 5.8, 2.2], - [7.6, 3.0, 6.6, 2.1], - [4.9, 2.5, 4.5, 1.7], - [7.3, 2.9, 6.3, 1.8], - [6.7, 2.5, 5.8, 1.8], - [7.2, 3.6, 6.1, 2.5], - [6.5, 3.2, 5.1, 2.0], - [6.4, 2.7, 5.3, 1.9], - [6.8, 3.0, 5.5, 2.1], - [5.7, 2.5, 5.0, 2.0], - [5.8, 2.8, 5.1, 2.4], - [6.4, 3.2, 5.3, 2.3], - [6.5, 3.0, 5.5, 1.8], - [7.7, 3.8, 6.7, 2.2], - [7.7, 2.6, 6.9, 2.3], - [6.0, 2.2, 5.0, 1.5], - [6.9, 3.2, 5.7, 2.3], - [5.6, 2.8, 4.9, 2.0], - [7.7, 2.8, 6.7, 2.0], - [6.3, 2.7, 4.9, 1.8], - [6.7, 3.3, 5.7, 2.1], - [7.2, 3.2, 6.0, 1.8], - [6.2, 2.8, 4.8, 1.8], - [6.1, 3.0, 4.9, 1.8], - [6.4, 2.8, 5.6, 2.1], - [7.2, 3.0, 5.8, 1.6], - [7.4, 2.8, 6.1, 1.9], - [7.9, 3.8, 6.4, 2.0], - [6.4, 2.8, 5.6, 2.2], - [6.3, 2.8, 5.1, 1.5], - [6.1, 2.6, 5.6, 1.4], - [7.7, 3.0, 6.1, 2.3], - [6.3, 3.4, 5.6, 2.4], - [6.4, 3.1, 5.5, 1.8], - [6.0, 3.0, 4.8, 1.8], - [6.9, 3.1, 5.4, 2.1], - [6.7, 3.1, 5.6, 2.4], - [6.9, 3.1, 5.1, 2.3], - [5.8, 2.7, 5.1, 1.9], - [6.8, 3.2, 5.9, 2.3], - [6.7, 3.3, 5.7, 2.5], - [6.7, 3.0, 5.2, 2.3], - [6.3, 2.5, 5.0, 1.9], - [6.5, 3.0, 5.2, 2.0], - [6.2, 3.4, 5.4, 2.3], - [5.9, 3.0, 5.1, 1.8] - ]) - - y = Nx.concatenate([Nx.broadcast(0, {50}), Nx.broadcast(1, {50}), Nx.broadcast(2, {50})]) - - shuffle = Nx.iota({Nx.axis_size(x, 0)}) - {shuffle, _} = Nx.Random.shuffle(key, shuffle) - x = Nx.take(x, shuffle) - y = Nx.take(y, shuffle) - {x_train, x_test} = Nx.split(x, 120) - {y_train, y_test} = Nx.split(y, 120) - {x_train, x_test, y_train, y_test} - end - test "Iris Data Set - multinomial logistic regression test" do {x_train, x_test, y_train, y_test} = iris_data() model = LogisticRegression.fit(x_train, y_train, num_classes: 3) res = LogisticRegression.predict(model, x_test) - assert Scholar.Metrics.Classification.accuracy(y_test, res) >= 0.965 + accuracy = Scholar.Metrics.Classification.accuracy(res, y_test) + + assert Nx.greater_equal(accuracy, 0.96) == Nx.u8(1) end describe "errors" do diff --git a/test/scholar/linear/svm_test.exs b/test/scholar/linear/svm_test.exs new file mode 100644 index 00000000..d532d456 --- /dev/null +++ b/test/scholar/linear/svm_test.exs @@ -0,0 +1,20 @@ +defmodule Scholar.Linear.SVMTest do + use Scholar.Case, async: true + alias Scholar.Linear.SVM + doctest SVM + + test "Iris Data Set - multinomial classification svm test" do + {x_train, x_test, y_train, y_test} = iris_data() + + loss_fn = fn y_pred, y_true -> + Scholar.Linear.SVM.hinge_loss(y_pred, y_true, c: 1.0, margin: 150) + end + + model = SVM.fit(x_train, y_train, num_classes: 3, loss_fn: loss_fn) + res = SVM.predict(model, x_test) + + accuracy = Scholar.Metrics.Classification.accuracy(res, y_test) + + assert Nx.greater_equal(accuracy, 0.96) == Nx.u8(1) + end +end diff --git a/test/support/scholar_case.ex b/test/support/scholar_case.ex index bcb7d898..62ba6ba5 100644 --- a/test/support/scholar_case.ex +++ b/test/support/scholar_case.ex @@ -25,4 +25,173 @@ defmodule Scholar.Case do """) end end + + def iris_data do + key = Nx.Random.key(42) + + x = + Nx.tensor([ + [5.1, 3.5, 1.4, 0.2], + [4.9, 3.0, 1.4, 0.2], + [4.7, 3.2, 1.3, 0.2], + [4.6, 3.1, 1.5, 0.2], + [5.0, 3.6, 1.4, 0.2], + [5.4, 3.9, 1.7, 0.4], + [4.6, 3.4, 1.4, 0.3], + [5.0, 3.4, 1.5, 0.2], + [4.4, 2.9, 1.4, 0.2], + [4.9, 3.1, 1.5, 0.1], + [5.4, 3.7, 1.5, 0.2], + [4.8, 3.4, 1.6, 0.2], + [4.8, 3.0, 1.4, 0.1], + [4.3, 3.0, 1.1, 0.1], + [5.8, 4.0, 1.2, 0.2], + [5.7, 4.4, 1.5, 0.4], + [5.4, 3.9, 1.3, 0.4], + 
[5.1, 3.5, 1.4, 0.3], + [5.7, 3.8, 1.7, 0.3], + [5.1, 3.8, 1.5, 0.3], + [5.4, 3.4, 1.7, 0.2], + [5.1, 3.7, 1.5, 0.4], + [4.6, 3.6, 1.0, 0.2], + [5.1, 3.3, 1.7, 0.5], + [4.8, 3.4, 1.9, 0.2], + [5.0, 3.0, 1.6, 0.2], + [5.0, 3.4, 1.6, 0.4], + [5.2, 3.5, 1.5, 0.2], + [5.2, 3.4, 1.4, 0.2], + [4.7, 3.2, 1.6, 0.2], + [4.8, 3.1, 1.6, 0.2], + [5.4, 3.4, 1.5, 0.4], + [5.2, 4.1, 1.5, 0.1], + [5.5, 4.2, 1.4, 0.2], + [4.9, 3.1, 1.5, 0.1], + [5.0, 3.2, 1.2, 0.2], + [5.5, 3.5, 1.3, 0.2], + [4.9, 3.1, 1.5, 0.1], + [4.4, 3.0, 1.3, 0.2], + [5.1, 3.4, 1.5, 0.2], + [5.0, 3.5, 1.3, 0.3], + [4.5, 2.3, 1.3, 0.3], + [4.4, 3.2, 1.3, 0.2], + [5.0, 3.5, 1.6, 0.6], + [5.1, 3.8, 1.9, 0.4], + [4.8, 3.0, 1.4, 0.3], + [5.1, 3.8, 1.6, 0.2], + [4.6, 3.2, 1.4, 0.2], + [5.3, 3.7, 1.5, 0.2], + [5.0, 3.3, 1.4, 0.2], + [7.0, 3.2, 4.7, 1.4], + [6.4, 3.2, 4.5, 1.5], + [6.9, 3.1, 4.9, 1.5], + [5.5, 2.3, 4.0, 1.3], + [6.5, 2.8, 4.6, 1.5], + [5.7, 2.8, 4.5, 1.3], + [6.3, 3.3, 4.7, 1.6], + [4.9, 2.4, 3.3, 1.0], + [6.6, 2.9, 4.6, 1.3], + [5.2, 2.7, 3.9, 1.4], + [5.0, 2.0, 3.5, 1.0], + [5.9, 3.0, 4.2, 1.5], + [6.0, 2.2, 4.0, 1.0], + [6.1, 2.9, 4.7, 1.4], + [5.6, 2.9, 3.6, 1.3], + [6.7, 3.1, 4.4, 1.4], + [5.6, 3.0, 4.5, 1.5], + [5.8, 2.7, 4.1, 1.0], + [6.2, 2.2, 4.5, 1.5], + [5.6, 2.5, 3.9, 1.1], + [5.9, 3.2, 4.8, 1.8], + [6.1, 2.8, 4.0, 1.3], + [6.3, 2.5, 4.9, 1.5], + [6.1, 2.8, 4.7, 1.2], + [6.4, 2.9, 4.3, 1.3], + [6.6, 3.0, 4.4, 1.4], + [6.8, 2.8, 4.8, 1.4], + [6.7, 3.0, 5.0, 1.7], + [6.0, 2.9, 4.5, 1.5], + [5.7, 2.6, 3.5, 1.0], + [5.5, 2.4, 3.8, 1.1], + [5.5, 2.4, 3.7, 1.0], + [5.8, 2.7, 3.9, 1.2], + [6.0, 2.7, 5.1, 1.6], + [5.4, 3.0, 4.5, 1.5], + [6.0, 3.4, 4.5, 1.6], + [6.7, 3.1, 4.7, 1.5], + [6.3, 2.3, 4.4, 1.3], + [5.6, 3.0, 4.1, 1.3], + [5.5, 2.5, 4.0, 1.3], + [5.5, 2.6, 4.4, 1.2], + [6.1, 3.0, 4.6, 1.4], + [5.8, 2.6, 4.0, 1.2], + [5.0, 2.3, 3.3, 1.0], + [5.6, 2.7, 4.2, 1.3], + [5.7, 3.0, 4.2, 1.2], + [5.7, 2.9, 4.2, 1.3], + [6.2, 2.9, 4.3, 1.3], + [5.1, 2.5, 3.0, 1.1], + [5.7, 2.8, 4.1, 1.3], + [6.3, 3.3, 6.0, 2.5], + [5.8, 2.7, 5.1, 1.9], + [7.1, 3.0, 5.9, 2.1], + [6.3, 2.9, 5.6, 1.8], + [6.5, 3.0, 5.8, 2.2], + [7.6, 3.0, 6.6, 2.1], + [4.9, 2.5, 4.5, 1.7], + [7.3, 2.9, 6.3, 1.8], + [6.7, 2.5, 5.8, 1.8], + [7.2, 3.6, 6.1, 2.5], + [6.5, 3.2, 5.1, 2.0], + [6.4, 2.7, 5.3, 1.9], + [6.8, 3.0, 5.5, 2.1], + [5.7, 2.5, 5.0, 2.0], + [5.8, 2.8, 5.1, 2.4], + [6.4, 3.2, 5.3, 2.3], + [6.5, 3.0, 5.5, 1.8], + [7.7, 3.8, 6.7, 2.2], + [7.7, 2.6, 6.9, 2.3], + [6.0, 2.2, 5.0, 1.5], + [6.9, 3.2, 5.7, 2.3], + [5.6, 2.8, 4.9, 2.0], + [7.7, 2.8, 6.7, 2.0], + [6.3, 2.7, 4.9, 1.8], + [6.7, 3.3, 5.7, 2.1], + [7.2, 3.2, 6.0, 1.8], + [6.2, 2.8, 4.8, 1.8], + [6.1, 3.0, 4.9, 1.8], + [6.4, 2.8, 5.6, 2.1], + [7.2, 3.0, 5.8, 1.6], + [7.4, 2.8, 6.1, 1.9], + [7.9, 3.8, 6.4, 2.0], + [6.4, 2.8, 5.6, 2.2], + [6.3, 2.8, 5.1, 1.5], + [6.1, 2.6, 5.6, 1.4], + [7.7, 3.0, 6.1, 2.3], + [6.3, 3.4, 5.6, 2.4], + [6.4, 3.1, 5.5, 1.8], + [6.0, 3.0, 4.8, 1.8], + [6.9, 3.1, 5.4, 2.1], + [6.7, 3.1, 5.6, 2.4], + [6.9, 3.1, 5.1, 2.3], + [5.8, 2.7, 5.1, 1.9], + [6.8, 3.2, 5.9, 2.3], + [6.7, 3.3, 5.7, 2.5], + [6.7, 3.0, 5.2, 2.3], + [6.3, 2.5, 5.0, 1.9], + [6.5, 3.0, 5.2, 2.0], + [6.2, 3.4, 5.4, 2.3], + [5.9, 3.0, 5.1, 1.8] + ]) + + x = Scholar.Preprocessing.standard_scale(x, axes: [-1]) + y = Nx.concatenate([Nx.broadcast(0, {50}), Nx.broadcast(1, {50}), Nx.broadcast(2, {50})]) + + shuffle = Nx.iota({Nx.axis_size(x, 0)}) + {shuffle, _} = Nx.Random.shuffle(key, shuffle) + x = Nx.take(x, shuffle) + y = Nx.take(y, shuffle) + {x_train, x_test} = Nx.split(x, 120) + {y_train, y_test} = 
Nx.split(y, 120) + {x_train, x_test, y_train, y_test} + end end