diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex
index bfb584af..97ecbaab 100644
--- a/lib/scholar/neighbors/brute_knn.ex
+++ b/lib/scholar/neighbors/brute_knn.ex
@@ -25,7 +25,7 @@ defmodule Scholar.Neighbors.BruteKNN do
       type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]},
       default: {:minkowski, 2},
       doc: ~S"""
-      The function that measures distance between two points. Possible values:
+      The function that measures the distance between two points. Possible values:

       * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
       we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex
index 1228a935..3c3bbead 100644
--- a/lib/scholar/neighbors/kd_tree.ex
+++ b/lib/scholar/neighbors/kd_tree.ex
@@ -318,11 +318,20 @@ defmodule Scholar.Neighbors.KDTree do
   """
   deftransform predict(tree, data) do
     if Nx.rank(data) != 2 do
-      raise ArgumentError, "Input data must be a 2D tensor"
+      raise ArgumentError,
+            """
+            expected query tensor to have shape {num_queries, num_features}, \
+            got tensor with shape: #{inspect(Nx.shape(data))}
+            """
     end

-    if Nx.axis_size(data, -1) != Nx.axis_size(tree.data, -1) do
-      raise ArgumentError, "Input data must have the same number of features as the training data"
+    if Nx.axis_size(tree.data, 1) != Nx.axis_size(data, 1) do
+      raise ArgumentError,
+            """
+            expected query tensor to have the same number of features as the tensor used to fit the tree, \
+            got #{inspect(Nx.axis_size(data, 1))} \
+            and #{inspect(Nx.axis_size(tree.data, 1))}
+            """
     end

     predict_n(tree, data)
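A quick sketch of what callers now see with the reworded KDTree checks (the session below is illustrative only; fit/2 and predict/2 are used exactly as in the doctests elsewhere in this PR):

    tree = Scholar.Neighbors.KDTree.fit(Nx.iota({5, 2}), num_neighbors: 2)
    # Rank check: a 1-D query now reports the offending shape.
    Scholar.Neighbors.KDTree.predict(tree, Nx.tensor([1, 2]))
    #=> ** (ArgumentError) expected query tensor to have shape {num_queries, num_features}, got tensor with shape: {2}
    # Feature-count check: 3 query features vs. the 2 the tree was fitted on.
    Scholar.Neighbors.KDTree.predict(tree, Nx.tensor([[1, 2, 3]]))
    #=> ** (ArgumentError) expected query tensor to have the same number of features as the tensor used to fit the tree, got 3 and 2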
diff --git a/lib/scholar/neighbors/knn_classifier.ex b/lib/scholar/neighbors/knn_classifier.ex
new file mode 100644
index 00000000..d3acb5c3
--- /dev/null
+++ b/lib/scholar/neighbors/knn_classifier.ex
@@ -0,0 +1,287 @@
+defmodule Scholar.Neighbors.KNNClassifier do
+  @moduledoc """
+  K-Nearest Neighbors Classifier.
+
+  Performs classification by computing a (weighted) majority vote among the k-nearest neighbors.
+  """
+
+  import Nx.Defn
+  import Scholar.Shared
+  require Nx
+
+  @derive {Nx.Container, keep: [:num_classes, :weights], containers: [:algorithm, :labels]}
+  defstruct [:algorithm, :num_classes, :weights, :labels]
+
+  opts = [
+    algorithm: [
+      type: :atom,
+      default: :brute,
+      doc: """
+      Algorithm used to compute the k-nearest neighbors. Possible values:
+
+        * `:brute` - Brute-force search. See `Scholar.Neighbors.BruteKNN` for more details.
+
+        * `:kd_tree` - k-d tree. See `Scholar.Neighbors.KDTree` for more details.
+
+        * `:random_projection_forest` - Random projection forest. See `Scholar.Neighbors.RandomProjectionForest` for more details.
+
+        * Module implementing `fit(data, opts)` and `predict(model, query)`. `predict/2` must return a tuple containing the indices
+        of the k-nearest neighbors of the query points as well as the distances between the query points and their k-nearest neighbors.
+      """
+    ],
+    num_classes: [
+      required: true,
+      type: :pos_integer,
+      doc: "The number of possible classes."
+    ],
+    weights: [
+      type: {:in, [:uniform, :distance]},
+      default: :uniform,
+      doc: """
+      Weight function used in prediction. Possible values:
+
+        * `:uniform` - uniform weights. All points in each neighborhood are weighted equally.
+
+        * `:distance` - weight points by the inverse of their distance. In this case, closer neighbors of
+        a query point will have a greater influence than neighbors which are further away.
+      """
+    ]
+  ]
+
+  @opts_schema NimbleOptions.new!(opts)
+
+  @doc """
+  Fits a k-NN classifier model.
+
+  ## Options
+
+  #{NimbleOptions.docs(@opts_schema)}
+
+  Algorithm-specific options (e.g. `:num_neighbors`, `:metric`) should be provided together with the classifier options.
+
+  ## Examples
+
+      iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> y = Nx.tensor([0, 0, 0, 1, 1])
+      iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, num_neighbors: 3, num_classes: 2)
+      iex> model.algorithm
+      Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: 3)
+      iex> model.labels
+      Nx.tensor([0, 0, 0, 1, 1])
+
+      iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> y = Nx.tensor([0, 0, 0, 1, 1])
+      iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, algorithm: :kd_tree, num_neighbors: 3, metric: {:minkowski, 1}, num_classes: 2)
+      iex> model.algorithm
+      Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1})
+      iex> model.labels
+      Nx.tensor([0, 0, 0, 1, 1])
+
+      iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> y = Nx.tensor([0, 0, 0, 1, 1])
+      iex> key = Nx.Random.key(12)
+      iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, algorithm: :random_projection_forest, num_neighbors: 2, num_classes: 2, num_trees: 4, key: key)
+      iex> model.algorithm
+      Scholar.Neighbors.RandomProjectionForest.fit(x, num_neighbors: 2, num_trees: 4, key: key)
+      iex> model.labels
+      Nx.tensor([0, 0, 0, 1, 1])
+  """
+  deftransform fit(x, y, opts) do
+    if Nx.rank(x) != 2 do
+      raise ArgumentError,
+            """
+            expected x to have shape {num_samples, num_features}, \
+            got tensor with shape: #{inspect(Nx.shape(x))}
+            """
+    end
+
+    if Nx.rank(y) != 1 do
+      raise ArgumentError,
+            """
+            expected y to have shape {num_samples}, \
+            got tensor with shape: #{inspect(Nx.shape(y))}
+            """
+    end
+
+    if Nx.axis_size(x, 0) != Nx.axis_size(y, 0) do
+      raise ArgumentError,
+            """
+            expected x and y to have the same first dimension, \
+            got #{Nx.axis_size(x, 0)} and #{Nx.axis_size(y, 0)}
+            """
+    end
+
+    {opts, algorithm_opts} = Keyword.split(opts, [:algorithm, :num_classes, :weights])
+    opts = NimbleOptions.validate!(opts, @opts_schema)
+
+    algorithm_module =
+      case opts[:algorithm] do
+        :brute ->
+          Scholar.Neighbors.BruteKNN
+
+        :kd_tree ->
+          Scholar.Neighbors.KDTree
+
+        :random_projection_forest ->
+          Scholar.Neighbors.RandomProjectionForest
+
+        module when is_atom(module) ->
+          module
+      end
+
+    algorithm = algorithm_module.fit(x, algorithm_opts)
+
+    %__MODULE__{
+      algorithm: algorithm,
+      num_classes: opts[:num_classes],
+      labels: y,
+      weights: opts[:weights]
+    }
+  end
+
+  @doc """
+  Predicts classes using a k-NN classifier model.
+
+  ## Examples
+
+      iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
+      iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
+      iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
+      iex> Scholar.Neighbors.KNNClassifier.predict(model, x)
+      Nx.tensor([0, 0, 1])
+  """
+  defn predict(model, x) do
+    {neighbors, distances} = compute_knn(model.algorithm, x)
+    neighbor_labels = Nx.take(model.labels, neighbors)
+
+    case model.weights do
+      :uniform -> Nx.mode(neighbor_labels, axis: 1)
+      :distance -> weighted_mode(neighbor_labels, check_weights(distances))
+    end
+  end
+
+  @doc """
+  Predicts class probabilities using a k-NN classifier model.
+
+  ## Examples
+
+      iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
+      iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
+      iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
+      iex> Scholar.Neighbors.KNNClassifier.predict_probability(model, x)
+      Nx.tensor(
+        [
+          [1.0, 0.0],
+          [1.0, 0.0],
+          [0.3333333432674408, 0.6666666865348816]
+        ]
+      )
+  """
+  defn predict_probability(model, x) do
+    num_samples = Nx.axis_size(x, 0)
+    type = to_float_type(x)
+    {neighbors, distances} = compute_knn(model.algorithm, x)
+    neighbor_labels = Nx.take(model.labels, neighbors)
+    proba = Nx.broadcast(Nx.tensor(0.0, type: type), {num_samples, model.num_classes})
+
+    weights =
+      case model.weights do
+        :uniform -> Nx.broadcast(1.0, neighbors)
+        :distance -> check_weights(distances)
+      end
+
+    indices =
+      Nx.stack(
+        [Nx.iota(Nx.shape(neighbor_labels), axis: 0), neighbor_labels],
+        axis: 2
+      )
+      |> Nx.flatten(axes: [0, 1])
+
+    proba = Nx.indexed_add(proba, indices, Nx.flatten(weights))
+    normalizer = Nx.sum(proba, axes: [1])
+    normalizer = Nx.select(normalizer == 0, 1, normalizer)
+    proba / Nx.new_axis(normalizer, 1)
+  end
+
+  deftransformp compute_knn(algorithm, x) do
+    algorithm.__struct__.predict(algorithm, x)
+  end
+
+  defnp check_weights(weights) do
+    zero_mask = weights == 0
+    zero_rows = zero_mask |> Nx.any(axes: [1], keep_axes: true) |> Nx.broadcast(weights)
+    weights = Nx.select(zero_mask, 1, weights)
+    weights_inv = 1 / weights
+    Nx.select(zero_rows, Nx.select(zero_mask, 1, 0), weights_inv)
+  end
+
+  defnp weighted_mode(tensor, weights) do
+    tensor_size = Nx.size(tensor)
+
+    cond do
+      tensor_size == 1 ->
+        Nx.squeeze(tensor, axes: [1])
+
+      true ->
+        weighted_mode_general(tensor, weights)
+    end
+  end
+
+  defnp weighted_mode_general(tensor, weights) do
+    {num_samples, num_features} = tensor_shape = Nx.shape(tensor)
+
+    indices = Nx.argsort(tensor, axis: 1)
+
+    sorted = Nx.take_along_axis(tensor, indices, axis: 1)
+
+    size_to_broadcast = {num_samples, 1}
+
+    group_indices =
+      Nx.concatenate(
+        [
+          Nx.broadcast(0, size_to_broadcast),
+          Nx.not_equal(
+            Nx.slice_along_axis(sorted, 0, Nx.axis_size(sorted, 1) - 1, axis: 1),
+            Nx.slice_along_axis(sorted, 1, Nx.axis_size(sorted, 1) - 1, axis: 1)
+          )
+        ],
+        axis: 1
+      )
+      |> Nx.cumulative_sum(axis: 1)
+
+    num_elements = Nx.size(tensor_shape)
+
+    counting_indices =
+      [
+        Nx.shape(group_indices)
+        |> Nx.iota(axis: 0)
+        |> Nx.reshape({num_elements, 1}),
+        Nx.reshape(group_indices, {num_elements, 1})
+      ]
+      |> Nx.concatenate(axis: 1)
+
+    to_add = Nx.flatten(weights)
+
+    indices =
+      (indices + num_features * Nx.iota(tensor_shape, axis: 0))
+      |> Nx.flatten()
+
+    weights = Nx.take(to_add, indices)
+
+    largest_group_indices =
+      Nx.broadcast(0, sorted)
+      |> Nx.indexed_add(counting_indices, weights)
+      |> Nx.argmax(axis: 1, keep_axis: true)
+
+    indices =
+      largest_group_indices
+      |> Nx.broadcast(group_indices)
+      |> Nx.equal(group_indices)
+      |> Nx.argmax(axis: 1, keep_axis: true)
+
+    res = Nx.take_along_axis(sorted, indices, axis: 1)
+    Nx.squeeze(res, axes: [1])
+  end
+end
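The `:algorithm` option also accepts any module that follows the fit/predict contract described in the moduledoc above. A minimal sketch of such a plug-in (the module name and internals are hypothetical; only the contract comes from this PR):

    defmodule MyNeighbors do
      # The classifier nests this struct inside its own Nx.Container,
      # so a plug-in struct must be a container as well.
      @derive {Nx.Container, keep: [:num_neighbors], containers: [:data]}
      defstruct [:data, :num_neighbors]

      def fit(data, opts) do
        %__MODULE__{data: data, num_neighbors: opts[:num_neighbors]}
      end

      # Must return {neighbor_indices, distances}, both of shape
      # {num_queries, num_neighbors}.
      def predict(model, query) do
        # Brute-force squared Euclidean distances: {q, 1, d} - {1, n, d}.
        distances =
          query
          |> Nx.new_axis(1)
          |> Nx.subtract(Nx.new_axis(model.data, 0))
          |> Nx.pow(2)
          |> Nx.sum(axes: [2])

        neighbors =
          distances
          |> Nx.argsort(axis: 1)
          |> Nx.slice_along_axis(0, model.num_neighbors, axis: 1)

        {neighbors, Nx.take_along_axis(distances, neighbors, axis: 1)}
      end
    end

    # Used like the built-in algorithms:
    # Scholar.Neighbors.KNNClassifier.fit(x, y, algorithm: MyNeighbors, num_neighbors: 3, num_classes: 2)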
diff --git a/lib/scholar/neighbors/large_vis.ex b/lib/scholar/neighbors/large_vis.ex
index 8f98fdbf..076a9c9c 100644
--- a/lib/scholar/neighbors/large_vis.ex
+++ b/lib/scholar/neighbors/large_vis.ex
@@ -14,7 +14,7 @@ defmodule Scholar.Neighbors.LargeVis do
   import Nx.Defn
   import Scholar.Shared
   require Nx
-  alias Scholar.Neighbors.RandomProjectionForest, as: Forest
+  alias Scholar.Neighbors.RandomProjectionForest
   alias Scholar.Neighbors.Utils

   opts = [
@@ -23,6 +23,11 @@ defmodule Scholar.Neighbors.LargeVis do
       type: :pos_integer,
       doc: "The number of neighbors in the graph."
     ],
+    metric: [
+      type: {:in, [:squared_euclidean, :euclidean]},
+      default: :euclidean,
+      doc: "The function that measures the distance between two points."
+    ],
     min_leaf_size: [
       type: :pos_integer,
       doc: """
@@ -63,7 +68,7 @@ defmodule Scholar.Neighbors.LargeVis do

       iex> key = Nx.Random.key(12)
       iex> tensor = Nx.iota({5, 2})
-      iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, min_leaf_size: 2, num_trees: 3, key: key)
+      iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, metric: :squared_euclidean, min_leaf_size: 2, num_trees: 3, key: key)
       iex> graph
       #Nx.Tensor<
         u32[5][2]
@@ -98,6 +103,13 @@ defmodule Scholar.Neighbors.LargeVis do
     opts = NimbleOptions.validate!(opts, @opts_schema)
     k = opts[:num_neighbors]
+
+    metric =
+      case opts[:metric] do
+        :euclidean -> &Scholar.Metrics.Distance.euclidean/2
+        :squared_euclidean -> &Scholar.Metrics.Distance.squared_euclidean/2
+      end
+
     min_leaf_size = opts[:min_leaf_size] || max(10, 2 * k)

     size = Nx.axis_size(tensor, 0)
@@ -108,6 +120,7 @@ defmodule Scholar.Neighbors.LargeVis do
       tensor,
       key,
       num_neighbors: k,
+      metric: metric,
       min_leaf_size: min_leaf_size,
       num_trees: num_trees,
       num_iters: opts[:num_iters]
@@ -116,15 +129,15 @@ defmodule Scholar.Neighbors.LargeVis do

   defnp fit_n(tensor, key, opts) do
     forest =
-      Forest.fit(tensor,
+      RandomProjectionForest.fit(tensor,
         num_neighbors: opts[:num_neighbors],
         min_leaf_size: opts[:min_leaf_size],
         num_trees: opts[:num_trees],
         key: key
       )

-    {graph, _} = Forest.predict(forest, tensor)
-    expand(graph, tensor, num_iters: opts[:num_iters])
+    {graph, _} = RandomProjectionForest.predict(forest, tensor)
+    expand(graph, tensor, metric: opts[:metric], num_iters: opts[:num_iters])
   end

   defn expand(graph, tensor, opts) do
@@ -140,17 +153,20 @@ defmodule Scholar.Neighbors.LargeVis do
         {tensor, iter = 0}
       },
       iter < num_iters do
-        {expansion_iter(graph, tensor), {tensor, iter + 1}}
+        {expansion_iter(graph, tensor, metric: opts[:metric]), {tensor, iter + 1}}
       end

     result
   end

-  defnp expansion_iter(graph, tensor) do
+  defnp expansion_iter(graph, tensor, opts) do
     {size, k} = Nx.shape(graph)
     candidate_indices = Nx.take(graph, graph) |> Nx.reshape({size, k * k})
     candidate_indices = Nx.concatenate([graph, candidate_indices], axis: 1)

-    Utils.find_neighbors(tensor, tensor, candidate_indices, num_neighbors: k)
+    Utils.brute_force_search_with_candidates(tensor, tensor, candidate_indices,
+      num_neighbors: k,
+      metric: opts[:metric]
+    )
   end
 end
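Worth noting when upgrading: the old Utils.find_neighbors always computed squared Euclidean distances, while the new :metric option defaults to :euclidean, so the returned distances change unless the metric is passed explicitly. That is why the doctest above now does so; a sketch of the same call:

    key = Nx.Random.key(12)
    tensor = Nx.iota({5, 2})

    # Pass :squared_euclidean to reproduce the pre-PR distances.
    {graph, distances} =
      Scholar.Neighbors.LargeVis.fit(tensor,
        num_neighbors: 2,
        metric: :squared_euclidean,
        min_leaf_size: 2,
        num_trees: 3,
        key: key
      )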
diff --git a/lib/scholar/neighbors/random_projection_forest.ex b/lib/scholar/neighbors/random_projection_forest.ex
index 6a558e45..6c235b33 100644
--- a/lib/scholar/neighbors/random_projection_forest.ex
+++ b/lib/scholar/neighbors/random_projection_forest.ex
@@ -23,7 +23,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
   alias Scholar.Neighbors.Utils

   @derive {Nx.Container,
-           keep: [:num_neighbors, :depth, :leaf_size, :num_trees],
+           keep: [:num_neighbors, :metric, :depth, :leaf_size, :num_trees],
           containers: [:indices, :data, :hyperplanes, :medians]}
   @enforce_keys [
     :num_neighbors,
@@ -37,6 +37,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
   ]
   defstruct [
     :num_neighbors,
+    :metric,
     :depth,
     :leaf_size,
     :num_trees,
@@ -52,6 +53,11 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
       type: :pos_integer,
       doc: "The number of nearest neighbors."
     ],
+    metric: [
+      type: {:in, [:squared_euclidean, :euclidean]},
+      default: :euclidean,
+      doc: "The function that measures the distance between two points."
+    ],
     min_leaf_size: [
       type: :pos_integer,
       doc: "The minumum number of points in the leaf."
@@ -107,6 +113,12 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
     num_neighbors = opts[:num_neighbors]
     min_leaf_size = opts[:min_leaf_size]

+    metric =
+      case opts[:metric] do
+        :euclidean -> &Scholar.Metrics.Distance.euclidean/2
+        :squared_euclidean -> &Scholar.Metrics.Distance.squared_euclidean/2
+      end
+
     min_leaf_size =
       cond do
         is_nil(min_leaf_size) ->
@@ -142,6 +154,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do

     %__MODULE__{
       num_neighbors: num_neighbors,
+      metric: metric,
       depth: depth,
       leaf_size: leaf_size,
       num_trees: num_trees,
@@ -283,7 +296,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do

       iex> key = Nx.Random.key(12)
       iex> tensor = Nx.iota({5, 2})
-      iex> forest = Scholar.Neighbors.RandomProjectionForest.fit(tensor, num_neighbors: 2, num_trees: 3, key: key)
+      iex> forest = Scholar.Neighbors.RandomProjectionForest.fit(tensor, num_neighbors: 2, metric: :squared_euclidean, num_trees: 3, key: key)
       iex> query = Nx.tensor([[3, 4]])
       iex> {neighbors, distances} = Scholar.Neighbors.RandomProjectionForest.predict(forest, query)
       iex> neighbors
@@ -305,7 +318,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
     if Nx.rank(query) != 2 do
       raise ArgumentError,
             """
-            expected query tensor to have shape {num_samples, num_features}, \
+            expected query tensor to have shape {num_queries, num_features}, \
             got tensor with shape: #{inspect(Nx.shape(query))}
             """
     end
@@ -313,9 +326,9 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
     if Nx.axis_size(forest.data, 1) != Nx.axis_size(query, 1) do
       raise ArgumentError,
             """
-            expected query tensor to have the same dimension as tensor used to grow the forest, \
-            got #{inspect(Nx.axis_size(forest.data, 1))} \
-            and #{inspect(Nx.axis_size(query, 1))}
+            expected query tensor to have the same number of features as the tensor used to grow the forest, \
+            got #{inspect(Nx.axis_size(query, 1))} \
+            and #{inspect(Nx.axis_size(forest.data, 1))}
             """
     end
@@ -323,9 +336,12 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
   defnp predict_n(forest, query) do
-    k = forest.num_neighbors
     candidate_indices = get_leaves(forest, query)
-    Utils.find_neighbors(query, forest.data, candidate_indices, num_neighbors: k)
+
+    Utils.brute_force_search_with_candidates(forest.data, query, candidate_indices,
+      num_neighbors: forest.num_neighbors,
+      metric: forest.metric
+    )
   end

   @doc false
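As in LargeVis, the metric atom is resolved to a 2-arity function capture in the deftransform layer and stored on the struct (in the Nx.Container :keep list), so predict_n/2 can call it inside defn without re-validating options. A small sketch of the resulting field (output shown for the default; values illustrative):

    key = Nx.Random.key(12)
    forest = Scholar.Neighbors.RandomProjectionForest.fit(Nx.iota({5, 2}), num_neighbors: 2, num_trees: 3, key: key)
    forest.metric
    #=> &Scholar.Metrics.Distance.euclidean/2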
diff --git a/lib/scholar/neighbors/utils.ex b/lib/scholar/neighbors/utils.ex
index f202348b..06acc8c4 100644
--- a/lib/scholar/neighbors/utils.ex
+++ b/lib/scholar/neighbors/utils.ex
@@ -3,16 +3,20 @@ defmodule Scholar.Neighbors.Utils do
   import Nx.Defn
   require Nx

-  defn find_neighbors(query, data, candidate_indices, opts) do
+  defn brute_force_search_with_candidates(data, query, candidate_indices, opts) do
     k = opts[:num_neighbors]
+    metric = opts[:metric]
+    dim = Nx.axis_size(data, 1)
     {size, length} = Nx.shape(candidate_indices)

-    distances =
+    x =
       query
       |> Nx.new_axis(1)
-      |> Nx.subtract(Nx.take(data, candidate_indices))
-      |> Nx.pow(2)
-      |> Nx.sum(axes: [2])
+      |> Nx.broadcast({size, length, dim})
+      |> Nx.vectorize([:query, :candidates])
+
+    y = Nx.take(data, candidate_indices) |> Nx.vectorize([:query, :candidates])
+    distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil)

     distances =
       if length > 1 do
@@ -55,11 +59,11 @@ defmodule Scholar.Neighbors.Utils do
     target = Nx.broadcast(Nx.u32(0), {size, length})
     samples = Nx.iota({size, length, 1}, axis: 0)

-    indices =
+    target_indices =
       Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2)
       |> Nx.reshape({size * length, 2})

     updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length})

-    Nx.indexed_add(target, indices, updates)
+    Nx.indexed_add(target, target_indices, updates)
   end
 end
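The rewritten helper hands the distance computation to the metric via vectorization: both operands are shaped to {size, length, dim}, the two leading axes are vectorized away, and the metric then runs on plain {dim} vectors. A standalone sketch of the pattern (the tensors are made up; the metric is one of the captures resolved above):

    data = Nx.tensor([[0.0, 0.0], [3.0, 4.0]])
    query = Nx.tensor([[0.0, 0.0], [1.0, 1.0]])
    candidate_indices = Nx.tensor([[0, 1], [1, 0]])
    {size, length} = Nx.shape(candidate_indices)
    dim = Nx.axis_size(data, 1)

    x =
      query
      |> Nx.new_axis(1)
      |> Nx.broadcast({size, length, dim})
      |> Nx.vectorize([:query, :candidates])

    y = Nx.take(data, candidate_indices) |> Nx.vectorize([:query, :candidates])

    # The metric sees {dim}-shaped tensors; Nx maps it over both vectorized axes.
    Scholar.Metrics.Distance.euclidean(x, y) |> Nx.devectorize() |> Nx.rename(nil)
    #=> a {2, 2} tensor of per-candidate distances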
diff --git a/test/scholar/metrics/neighbors_test.exs b/test/scholar/metrics/neighbors_test.exs
new file mode 100644
index 00000000..833928aa
--- /dev/null
+++ b/test/scholar/metrics/neighbors_test.exs
@@ -0,0 +1,5 @@
+defmodule Scholar.Metrics.NeighborsTest do
+  use ExUnit.Case, async: true
+  alias Scholar.Metrics.Neighbors
+  doctest Neighbors
+end
diff --git a/test/scholar/neighbors/knn_classifier_test.exs b/test/scholar/neighbors/knn_classifier_test.exs
new file mode 100644
index 00000000..4df48d77
--- /dev/null
+++ b/test/scholar/neighbors/knn_classifier_test.exs
@@ -0,0 +1,219 @@
+defmodule Scholar.Neighbors.KNNClassifierTest do
+  use Scholar.Case, async: true
+  alias Scholar.Neighbors.KNNClassifier
+  doctest KNNClassifier
+
+  defp x_train do
+    Nx.tensor([
+      [3, 6, 7, 5],
+      [9, 8, 5, 4],
+      [4, 4, 4, 1],
+      [9, 4, 5, 6],
+      [6, 4, 5, 7],
+      [4, 5, 3, 3],
+      [4, 5, 7, 8],
+      [9, 4, 4, 5],
+      [8, 4, 3, 9],
+      [2, 8, 4, 4]
+    ])
+  end
+
+  defp y_train do
+    Nx.tensor([0, 1, 1, 1, 1, 1, 1, 1, 0, 0])
+  end
+
+  defp x do
+    Nx.tensor([[4, 3, 8, 4], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]])
+  end
+
+  describe "fit" do
+    test "fit with default parameters - :num_classes set to 2" do
+      model = KNNClassifier.fit(x_train(), y_train(), num_neighbors: 3, num_classes: 2)
+
+      assert model.algorithm == Scholar.Neighbors.BruteKNN.fit(x_train(), num_neighbors: 3)
+      assert model.num_classes == 2
+      assert model.labels == y_train()
+      assert model.weights == :uniform
+    end
+
+    test "fit with k-d tree" do
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          algorithm: :kd_tree,
+          num_classes: 2,
+          num_neighbors: 3
+        )
+
+      assert model.algorithm == Scholar.Neighbors.KDTree.fit(x_train(), num_neighbors: 3)
+      assert model.num_classes == 2
+      assert model.labels == y_train()
+      assert model.weights == :uniform
+    end
+
+    test "fit with random projection forest" do
+      key = Nx.Random.key(12)
+
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          algorithm: :random_projection_forest,
+          num_classes: 2,
+          num_neighbors: 3,
+          num_trees: 4,
+          key: key
+        )
+
+      assert model.algorithm ==
+               Scholar.Neighbors.RandomProjectionForest.fit(x_train(),
+                 num_neighbors: 3,
+                 num_trees: 4,
+                 key: key
+               )
+
+      assert model.num_classes == 2
+      assert model.labels == y_train()
+      assert model.weights == :uniform
+    end
+  end
+
+  describe "predict" do
+    test "predict with default values" do
+      model = KNNClassifier.fit(x_train(), y_train(), num_neighbors: 3, num_classes: 2)
+      labels_pred = KNNClassifier.predict(model, x())
+      assert labels_pred == Nx.tensor([1, 1, 0, 1])
+    end
+
+    test "predict with k-d tree" do
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          algorithm: :kd_tree,
+          num_classes: 2,
+          num_neighbors: 3
+        )
+
+      labels_pred = KNNClassifier.predict(model, x())
+      assert labels_pred == Nx.tensor([1, 1, 0, 1])
+    end
+
+    test "predict with weights set to :distance" do
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          num_classes: 2,
+          num_neighbors: 3,
+          weights: :distance
+        )
+
+      labels_pred = KNNClassifier.predict(model, x())
+      assert labels_pred == Nx.tensor([1, 1, 0, 1])
+    end
+
+    test "predict with specific metric and weights set to :distance" do
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          num_classes: 2,
+          num_neighbors: 3,
+          metric: {:minkowski, 1.5},
+          weights: :distance
+        )
+
+      labels_pred = KNNClassifier.predict(model, x())
+      assert labels_pred == Nx.tensor([1, 1, 0, 1])
+    end
+
+    test "predict with weights set to :distance and with x that contains sample with zero-distance" do
+      x = Nx.tensor([[3, 6, 7, 5], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]])
+
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          num_classes: 2,
+          num_neighbors: 3,
+          weights: :distance
+        )
+
+      labels_pred = KNNClassifier.predict(model, x)
+      assert labels_pred == Nx.tensor([0, 1, 0, 1])
+    end
+  end
+
+  describe "predict_probability" do
+    test "predict_probability with default values" do
+      model = KNNClassifier.fit(x_train(), y_train(), num_classes: 2, num_neighbors: 3)
+      predictions = KNNClassifier.predict_probability(model, x())
+
+      assert_all_close(
+        predictions,
+        Nx.tensor([
+          [0.33333333, 0.66666667],
+          [0.33333333, 0.66666667],
+          [0.66666667, 0.33333333],
+          [0.0, 1.0]
+        ])
+      )
+    end
+
+    test "predict_probability with weights set to :distance" do
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          num_neighbors: 3,
+          num_classes: 2,
+          weights: :distance
+        )
+
+      predictions = KNNClassifier.predict_probability(model, x())
+
+      assert_all_close(
+        predictions,
+        Nx.tensor([
+          [0.40351151, 0.59648849],
+          [0.31717204, 0.68282796],
+          [0.7283494, 0.2716506],
+          [0.0, 1.0]
+        ])
+      )
+    end
+
+    test "predict_probability with weights set to :distance and with specific metric" do
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          num_classes: 2,
+          num_neighbors: 3,
+          weights: :distance,
+          metric: {:minkowski, 1.5}
+        )
+
+      predictions = KNNClassifier.predict_probability(model, x())
+
+      assert_all_close(
+        predictions,
+        Nx.tensor([
+          [0.40381038, 0.59618962],
+          [0.31457406, 0.68542594],
+          [0.72993802, 0.27006198],
+          [0.0, 1.0]
+        ])
+      )
+    end
+
+    test "predict_probability with weights set to :distance and with x that contains sample with zero-distance" do
+      x = Nx.tensor([[3, 6, 7, 5], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]])
+
+      model =
+        KNNClassifier.fit(x_train(), y_train(),
+          num_classes: 2,
+          num_neighbors: 3,
+          weights: :distance
+        )
+
+      predictions = KNNClassifier.predict_probability(model, x)
+
+      assert_all_close(
+        predictions,
+        Nx.tensor([
+          [1.0, 0.0],
+          [0.31717204, 0.68282796],
+          [0.7283494, 0.2716506],
+          [0.0, 1.0]
+        ])
+      )
+    end
+  end
+end
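One way to read the uniform-weights expectations above: with weights: :uniform, predict_probability reduces to the class frequency among the k neighbors. The first query row can be checked by hand (x_train/y_train as defined at the top of this file; the neighbor labels below are inferred from the expected output rather than recomputed):

    # [4, 3, 8, 4] has 3 nearest training points carrying labels {0, 1, 1}:
    # one vote for class 0, two votes for class 1.
    model = Scholar.Neighbors.KNNClassifier.fit(x_train(), y_train(), num_neighbors: 3, num_classes: 2)
    Scholar.Neighbors.KNNClassifier.predict_probability(model, Nx.tensor([[4, 3, 8, 4]]))
    #=> approximately [[1/3, 2/3]]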