From ebdae8f7b2fc9adc6c0a75e4981843009009d6b9 Mon Sep 17 00:00:00 2001 From: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> Date: Mon, 13 May 2024 18:18:29 +0200 Subject: [PATCH 1/6] Output distances in kdtree (#264) * Output distances in kdtree * Add checks on data in predict --- lib/scholar/neighbors/kd_tree.ex | 138 ++++++++++++++++-------- test/scholar/neighbors/kd_tree_test.exs | 82 ++++++++++---- 2 files changed, 158 insertions(+), 62 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index b8d918a5..1228a935 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -1,6 +1,6 @@ defmodule Scholar.Neighbors.KDTree do @moduledoc """ - Implements a kd-tree, a space-partitioning data structure for organizing points + Implements a k-d tree, a space-partitioning data structure for organizing points in a k-dimensional space. It can be used to predict the K-Nearest Neighbors of a given input. @@ -19,14 +19,13 @@ defmodule Scholar.Neighbors.KDTree do import Nx.Defn import Scholar.Shared - alias Scholar.Metrics.Distance - @derive {Nx.Container, keep: [:levels], containers: [:indices, :data]} - @enforce_keys [:levels, :indices, :data] - defstruct [:levels, :indices, :data] + @derive {Nx.Container, keep: [:levels, :num_neighbors, :metric], containers: [:indices, :data]} + @enforce_keys [:levels, :indices, :data, :num_neighbors, :metric] + defstruct [:levels, :indices, :data, :num_neighbors, :metric] opts = [ - k: [ + num_neighbors: [ type: :pos_integer, default: 3, doc: "The number of neighbors to use by default for `k_neighbors` queries" @@ -45,22 +44,48 @@ defmodule Scholar.Neighbors.KDTree do ] ] - @predict_schema NimbleOptions.new!(opts) + @opts_schema NimbleOptions.new!(opts) @doc """ Builds a KDTree. 
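+
+  The `:num_neighbors` and `:metric` options are stored on the tree and used
+  later by `predict/2`.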
  ## Examples

-      iex> Scholar.Neighbors.KDTree.fit(Nx.iota({5, 2}))
-      %Scholar.Neighbors.KDTree{
-        data: Nx.iota({5, 2}),
-        levels: 3,
-        indices: Nx.u32([3, 1, 4, 0, 2])
-      }
+      iex> tree = Scholar.Neighbors.KDTree.fit(Nx.iota({5, 2}))
+      iex> tree.data
+      Nx.tensor(
+        [
+          [0, 1],
+          [2, 3],
+          [4, 5],
+          [6, 7],
+          [8, 9]
+        ]
+      )
+      iex> tree.levels
+      3
+      iex> tree.indices
+      Nx.u32([3, 1, 4, 0, 2])
  """
-  deftransform fit(tensor, _opts \\ []) do
-    %__MODULE__{levels: levels(tensor), indices: fit_n(tensor), data: tensor}
+  deftransform fit(tensor, opts \\ []) do
+    opts = NimbleOptions.validate!(opts, @opts_schema)
+
+    metric =
+      case opts[:metric] do
+        {:minkowski, p} ->
+          &Scholar.Metrics.Distance.minkowski(&1, &2, p: p)
+
+        :cosine ->
+          &Scholar.Metrics.Distance.cosine/2
+      end
+
+    %__MODULE__{
+      levels: levels(tensor),
+      indices: fit_n(tensor),
+      data: tensor,
+      num_neighbors: opts[:num_neighbors],
+      metric: metric
+    }
   end

   defnp fit_n(tensor) do
@@ -247,8 +272,9 @@ defmodule Scholar.Neighbors.KDTree do

       iex> x = Nx.iota({10, 2})
       iex> x_predict = Nx.tensor([[2, 5], [1, 9], [6, 4]])
-      iex> kdtree = Scholar.Neighbors.KDTree.fit(x)
-      iex> Scholar.Neighbors.KDTree.predict(kdtree, x_predict, k: 3)
+      iex> kdtree = Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3)
+      iex> {indices, distances} = Scholar.Neighbors.KDTree.predict(kdtree, x_predict)
+      iex> indices
       #Nx.Tensor<
         s64[3][3]
         [
           [2, 1, 0],
           [2, 3, 1],
           [2, 3, 1]
         ]
       >
-      iex> Scholar.Neighbors.KDTree.predict(kdtree, x_predict, k: 3, metric: {:minkowski, 1})
+      iex> distances
+      #Nx.Tensor<
+        f32[3][3]
+        [
+          [2.0, 2.0, 4.4721360206604],
+          [5.0, 5.385164737701416, 6.082762718200684],
+          [2.2360680103302, 3.0, 4.123105525970459]
+        ]
+      >
+
+      iex> x = Nx.iota({10, 2})
+      iex> x_predict = Nx.tensor([[2, 5], [1, 9], [6, 4]])
+      iex> kdtree = Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1})
+      iex> {indices, distances} = Scholar.Neighbors.KDTree.predict(kdtree, x_predict)
+      iex> indices
       #Nx.Tensor<
         s64[3][3]
         [
           [2, 1, 0],
           [2, 3, 1],
           [2, 3, 1]
         ]
       >
+      iex> distances
+      #Nx.Tensor<
+        f32[3][3]
+        [
+          [2.0, 2.0, 6.0],
+          [7.0, 7.0, 7.0],
+          [3.0, 3.0, 5.0]
+        ]
+      >
  """
-  deftransform predict(tree, data, opts \\ []) do
-    predict_n(tree, data, NimbleOptions.validate!(opts, @predict_schema))
+  deftransform predict(tree, data) do
+    if Nx.rank(data) != 2 do
+      raise ArgumentError, "Input data must be a 2D tensor"
+    end
+
+    if Nx.axis_size(data, -1) != Nx.axis_size(tree.data, -1) do
+      raise ArgumentError, "Input data must have the same number of features as the training data"
+    end
+
+    predict_n(tree, data)
   end

   defnp sort_by_distances(distances, point_indices) do
@@ -276,16 +333,9 @@ defmodule Scholar.Neighbors.KDTree do
     {Nx.take(distances, indices), Nx.take(point_indices, indices)}
   end

-  defnp compute_distance(x1, x2, opts) do
-    case opts[:metric] do
-      {:minkowski, 2} -> Distance.squared_euclidean(x1, x2)
-      {:minkowski, p} -> Distance.minkowski(x1, x2, p: p)
-      :cosine -> Distance.cosine(x1, x2)
-    end
-  end
-
   defnp update_knn(nearest_neighbors, distances, data, indices, curr_node, point, k, opts) do
-    curr_dist = compute_distance(data[[indices[curr_node]]], point, opts)
+    metric = opts[:metric]
+    curr_dist = metric.(data[[indices[curr_node]]], point)

     if curr_dist < distances[[-1]] do
       nearest_neighbors =
@@ -311,8 +361,8 @@ defmodule Scholar.Neighbors.KDTree do
     end
   end

-  defnp predict_n(tree, point, opts) do
-    k = opts[:k]
+  defnp predict_n(tree, point) do
+    k = tree.num_neighbors
     node 
= Nx.as_type(root(), :s64) input_vectorized_axes = point.vectorized_axes @@ -330,6 +380,7 @@ defmodule Scholar.Neighbors.KDTree do indices = tree.indices |> Nx.as_type(:s64) data = tree.data + metric = tree.metric mode = down() i = Nx.s64(0) @@ -345,8 +396,8 @@ defmodule Scholar.Neighbors.KDTree do point ]) - {nearest_neighbors, _} = - while {nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}}, + {{nearest_neighbors, distances}, _} = + while {{nearest_neighbors, distances}, {node, data, indices, point, visited, i, mode}}, node != -1 and i >= 0 do coord_indicator = rem(i, dims) @@ -384,7 +435,7 @@ defmodule Scholar.Neighbors.KDTree do indices, point, k, - opts + metric: metric ) {parent(node), i - 1, visited, nearest_neighbors, distances, up()} @@ -402,14 +453,13 @@ defmodule Scholar.Neighbors.KDTree do indices, point, k, - opts + metric: metric ) if Nx.any( - compute_distance( + metric.( point[[coord_indicator]], - data[[indices[right_child(node)], coord_indicator]], - opts + data[[indices[right_child(node)], coord_indicator]] ) < distances ) do @@ -431,14 +481,13 @@ defmodule Scholar.Neighbors.KDTree do indices, point, k, - opts + metric: metric ) if Nx.any( - compute_distance( + metric.( point[[coord_indicator]], - data[[indices[left_child(node)], coord_indicator]], - opts + data[[indices[left_child(node)], coord_indicator]] ) < distances ) do @@ -457,10 +506,11 @@ defmodule Scholar.Neighbors.KDTree do {node, i, visited, nearest_neighbors, distances, none()} end - {nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}} + {{nearest_neighbors, distances}, {node, data, indices, point, visited, i, mode}} end - Nx.revectorize(nearest_neighbors, input_vectorized_axes, target_shape: {num_points, k}) + {Nx.revectorize(nearest_neighbors, input_vectorized_axes, target_shape: {num_points, k}), + Nx.revectorize(distances, input_vectorized_axes, target_shape: {num_points, k})} end defnp down(), do: Nx.u8(0) diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index af840cea..bf0c9f63 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -1,5 +1,5 @@ defmodule Scholar.Neighbors.KDTreeTest do - use ExUnit.Case, async: true + use Scholar.Case, async: true alias Scholar.Neighbors.KDTree doctest KDTree @@ -20,18 +20,24 @@ defmodule Scholar.Neighbors.KDTreeTest do describe "fit" do test "iota" do - assert %KDTree{levels: 3, indices: indices} = KDTree.fit(Nx.iota({5, 2})) - assert indices == Nx.u32([3, 1, 4, 0, 2]) + tree = KDTree.fit(Nx.iota({5, 2})) + assert tree.levels == 3 + assert tree.indices == Nx.u32([3, 1, 4, 0, 2]) + assert tree.num_neighbors == 3 end test "float" do - assert %KDTree{levels: 4, indices: indices} = KDTree.fit(example() |> Nx.as_type(:f32)) - assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + tree = KDTree.fit(Nx.as_type(example(), :f32)) + assert tree.levels == 4 + assert Nx.to_flat_list(tree.indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + assert tree.num_neighbors == 3 end test "sample" do - assert %KDTree{levels: 4, indices: indices} = KDTree.fit(example()) - assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + tree = KDTree.fit(example()) + assert tree.levels == 4 + assert Nx.to_flat_list(tree.indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + assert tree.num_neighbors == 3 end end @@ -57,30 +63,70 @@ defmodule Scholar.Neighbors.KDTreeTest do describe "predict knn" do test "all defaults" do kdtree = KDTree.fit(x()) + 
{indices, distances} = KDTree.predict(kdtree, x_pred()) - assert KDTree.predict(kdtree, x_pred()) == - Nx.tensor([[0, 6, 4], [5, 2, 9], [0, 9, 2], [5, 2, 7]]) + assert indices == Nx.tensor([[0, 6, 4], [5, 2, 9], [0, 9, 2], [5, 2, 7]]) + + assert_all_close( + distances, + Nx.tensor([ + [3.464101552963257, 4.582575798034668, 4.795831680297852], + [4.242640495300293, 4.690415859222412, 4.795831680297852], + [3.7416574954986572, 5.5677642822265625, 6.0], + [3.872983455657959, 3.872983455657959, 6.164413928985596] + ]) + ) end test "metric set to {:minkowski, 1.5}" do - kdtree = KDTree.fit(x()) + kdtree = KDTree.fit(x(), metric: {:minkowski, 1.5}) + {indices, distances} = KDTree.predict(kdtree, x_pred()) + + assert indices == Nx.tensor([[0, 6, 2], [5, 2, 9], [0, 9, 2], [5, 2, 7]]) - assert KDTree.predict(kdtree, x_pred(), metric: {:minkowski, 1.5}) == - Nx.tensor([[0, 6, 2], [5, 2, 9], [0, 9, 2], [5, 2, 7]]) + assert_all_close( + distances, + Nx.tensor([ + [4.065119743347168, 5.191402435302734, 5.862917423248291], + [5.198591709136963, 5.591182708740234, 5.869683265686035], + [4.334622859954834, 6.35192346572876, 6.9637274742126465], + [4.649191856384277, 4.649191856384277, 7.664907932281494] + ]) + ) end test "k set to 4" do - kdtree = KDTree.fit(x()) + kdtree = KDTree.fit(x(), num_neighbors: 4) + {indices, distances} = KDTree.predict(kdtree, x_pred()) + + assert indices == Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]]) - assert KDTree.predict(kdtree, x_pred(), k: 4) == - Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]]) + assert_all_close( + distances, + Nx.tensor([ + [3.464101552963257, 4.582575798034668, 4.795831680297852, 5.099019527435303], + [4.242640495300293, 4.690415859222412, 4.795831680297852, 7.4833149909973145], + [3.7416574954986572, 5.5677642822265625, 6.0, 6.480740547180176], + [3.872983455657959, 3.872983455657959, 6.164413928985596, 6.78233003616333] + ]) + ) end test "float type data" do - kdtree = KDTree.fit(x() |> Nx.as_type(:f64)) + kdtree = KDTree.fit(x() |> Nx.as_type(:f64), num_neighbors: 4) + {indices, distances} = KDTree.predict(kdtree, x_pred()) + + assert indices == Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]]) - assert KDTree.predict(kdtree, x_pred(), k: 4) == - Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]]) + assert_all_close( + distances, + Nx.tensor([ + [3.464101552963257, 4.582575798034668, 4.795831680297852, 5.099019527435303], + [4.242640495300293, 4.690415859222412, 4.795831680297852, 7.4833149909973145], + [3.7416574954986572, 5.5677642822265625, 6.0, 6.480740547180176], + [3.872983455657959, 3.872983455657959, 6.164413928985596, 6.78233003616333] + ]) + ) end end end From 584b3068b8e744edc270662f56aefd1bf6ba9d53 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Sun, 12 May 2024 20:52:16 +0200 Subject: [PATCH 2/6] Add distance to KDTree.predict/2 --- lib/scholar/neighbors/knn_classifier.ex | 245 ++++++++++++++++++ .../neighbors/random_projection_forest.ex | 2 +- test/scholar/metrics/neighbors_test.exs | 5 + .../scholar/neighbors/knn_classifier_test.exs | 112 ++++++++ 4 files changed, 363 insertions(+), 1 deletion(-) create mode 100644 lib/scholar/neighbors/knn_classifier.ex create mode 100644 test/scholar/metrics/neighbors_test.exs create mode 100644 test/scholar/neighbors/knn_classifier_test.exs diff --git a/lib/scholar/neighbors/knn_classifier.ex b/lib/scholar/neighbors/knn_classifier.ex new file mode 100644 index 00000000..f56b6e58 --- /dev/null +++ 
b/lib/scholar/neighbors/knn_classifier.ex @@ -0,0 +1,245 @@ +defmodule Scholar.Neighbors.KNNClassifier do + @moduledoc """ + K-Nearest Neighbors Classifier. + + ... + """ + + import Nx.Defn + require Nx + + @derive {Nx.Container, keep: [:algorithm, :num_classes, :weights], containers: [:labels]} + defstruct [:algorithm, :num_classes, :weights, :labels] + + opts = [ + algorithm: [ + type: {:or, [:atom, {:tuple, [:atom, :keyword_list]}]}, + default: :brute, + doc: """ + k-NN algorithm to be used for finding the nearest neighbors. It can be provided as + an atom or a tuple containing an atom and algorithm specific options. + Possible values for the atom: + + * `:brute` - Brute-force search. See `Scholar.Neighbors.BruteKNN` for more details. + + * `:kd_tree` - k-d tree. See `Scholar.Neighbors.KDTree` for more details. + + * `:random_projection_forest` - Random projection forest. See `Scholar.Neighbors.RandomProjectionForest` for more details. + + * Module implementing fit/2 and predict/2. + """ + ], + num_neighbors: [ + required: true, + type: :pos_integer, + doc: "The number of nearest neighbors." + ], + metric: [ + type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]}, + default: {:minkowski, 2}, + doc: """ + The function that measures distance between two points. Possible values: + + * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`) + we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric. + + * `:cosine` - Cosine metric. + + Keep in mind that different algorithms support different metrics. For more information have a look at the corresponding modules. + """ + ], + num_classes: [ + required: true, + type: :pos_integer, + doc: "The number of possible classes." + ], + weights: [ + type: {:in, [:uniform, :distance]}, + default: :uniform, + doc: """ + Weight function used in prediction. Possible values: + + * `:uniform` - uniform weights. All points in each neighborhood are weighted equally. + + * `:distance` - weight points by the inverse of their distance. in this case, closer neighbors of + a query point will have a greater influence than neighbors which are further away. + """ + ] + ] + + @opts_schema NimbleOptions.new!(opts) + + @doc """ + Fits a k-NN classifier model. 
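+
+  The training data `x` must have shape `{num_samples, num_features}` and the
+  labels `y` shape `{num_samples}`.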
+ + ## Options + + #{NimbleOptions.docs(@opts_schema)} + + ## Examples + + iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) + iex> y = Nx.tensor([0, 0, 0, 1, 1]) + iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, num_neighbors: 3, num_classes: 2) + iex> model.algorithm + Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: 3) + iex> model.labels + Nx.tensor([0, 0, 0, 1, 1]) + """ + deftransform fit(x, y, opts) do + if Nx.rank(x) != 2 do + raise ArgumentError, + "expected x to have shape {num_samples, num_features}, + got tensor with shape: #{inspect(Nx.shape(x))}" + end + + if Nx.rank(y) != 1 and Nx.axis_size(x, 0) == Nx.axis_size(y, 0) do + raise ArgumentError, + "expected y to have shape {num_samples}, + got tensor with shape: #{inspect(Nx.shape(y))}" + end + + opts = NimbleOptions.validate!(opts, @opts_schema) + + {algorithm_name, algorithm_opts} = + if is_atom(opts[:algorithm]) do + {opts[:algorithm], []} + else + opts[:algorithm] + end + + knn_module = + case algorithm_name do + :brute -> + Scholar.Neighbors.BruteKNN + + :kd_tree -> + Scholar.Neighbors.KDTree + + :random_projection_forest -> + Scholar.Neighbors.RandomProjectionForest + + knn_module when is_atom(knn_module) -> + knn_module + + _ -> + raise ArgumentError, + """ + not supported + """ + end + + # TODO: Maybe raise an error if :num_neighbors or :metric is already in algorithm_opts? + + algorithm_opts = Keyword.put(algorithm_opts, :num_neighbors, opts[:num_neighbors]) + algorithm_opts = Keyword.put(algorithm_opts, :metric, opts[:metric]) + + algorithm = knn_module.fit(x, algorithm_opts) + + %__MODULE__{ + algorithm: algorithm, + num_classes: opts[:num_classes], + labels: y, + weights: opts[:weights] + } + end + + @doc """ + Makes predictions using a k-NN classifier model. 
+ + ## Examples + + iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) + iex> y_train = Nx.tensor([0, 0, 0, 1, 1]) + iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2) + iex> x_test = Nx.tensor([[1, 3], [4, 2], [3, 6]]) + iex> Scholar.Neighbors.KNNClassifier.predict(model, x_test) + Nx.tensor([0, 0, 1]) + """ + deftransform predict(model, x) do + knn_module = model.algorithm.__struct__ + {neighbors, distances} = knn_module.predict(model.algorithm, x) + labels_pred = Nx.take(model.labels, neighbors) + + case model.weights do + :uniform -> Nx.mode(labels_pred, axis: 1) + :distance -> weighted_mode(labels_pred, check_weights(distances)) + end + end + + defnp check_weights(weights) do + zero_mask = weights == 0 + zero_rows = zero_mask |> Nx.any(axes: [1], keep_axes: true) |> Nx.broadcast(weights) + weights = Nx.select(zero_mask, 1, weights) + weights_inv = 1 / weights + Nx.select(zero_rows, Nx.select(zero_mask, 1, 0), weights_inv) + end + + defnp weighted_mode(tensor, weights) do + tensor_size = Nx.size(tensor) + + cond do + tensor_size == 1 -> + Nx.squeeze(tensor, axes: [1]) + + true -> + weighted_mode_general(tensor, weights) + end + end + + defnp weighted_mode_general(tensor, weights) do + {num_samples, num_features} = tensor_shape = Nx.shape(tensor) + + indices = Nx.argsort(tensor, axis: 1) + + sorted = Nx.take_along_axis(tensor, indices, axis: 1) + + size_to_broadcast = {num_samples, 1} + + group_indices = + Nx.concatenate( + [ + Nx.broadcast(0, size_to_broadcast), + Nx.not_equal( + Nx.slice_along_axis(sorted, 0, Nx.axis_size(sorted, 1) - 1, axis: 1), + Nx.slice_along_axis(sorted, 1, Nx.axis_size(sorted, 1) - 1, axis: 1) + ) + ], + axis: 1 + ) + |> Nx.cumulative_sum(axis: 1) + + num_elements = Nx.size(tensor_shape) + + counting_indices = + [ + Nx.shape(group_indices) + |> Nx.iota(axis: 0) + |> Nx.reshape({num_elements, 1}), + Nx.reshape(group_indices, {num_elements, 1}) + ] + |> Nx.concatenate(axis: 1) + + to_add = Nx.flatten(weights) + + indices = + (indices + num_features * Nx.iota(tensor_shape, axis: 0)) + |> Nx.flatten() + + weights = Nx.take(to_add, indices) + + largest_group_indices = + Nx.broadcast(0, sorted) + |> Nx.indexed_add(counting_indices, weights) + |> Nx.argmax(axis: 1, keep_axis: true) + + indices = + largest_group_indices + |> Nx.broadcast(group_indices) + |> Nx.equal(group_indices) + |> Nx.argmax(axis: 1, keep_axis: true) + + res = Nx.take_along_axis(sorted, indices, axis: 1) + Nx.squeeze(res, axes: [1]) + end +end diff --git a/lib/scholar/neighbors/random_projection_forest.ex b/lib/scholar/neighbors/random_projection_forest.ex index 6a558e45..5c89d173 100644 --- a/lib/scholar/neighbors/random_projection_forest.ex +++ b/lib/scholar/neighbors/random_projection_forest.ex @@ -305,7 +305,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do if Nx.rank(query) != 2 do raise ArgumentError, """ - expected query tensor to have shape {num_samples, num_features}, \ + expected query tensor to have shape {num_queries, num_features}, \ got tensor with shape: #{inspect(Nx.shape(query))} """ end diff --git a/test/scholar/metrics/neighbors_test.exs b/test/scholar/metrics/neighbors_test.exs new file mode 100644 index 00000000..833928aa --- /dev/null +++ b/test/scholar/metrics/neighbors_test.exs @@ -0,0 +1,5 @@ +defmodule Scholar.Metrics.NeighborsTest do + use ExUnit.Case, async: true + alias Scholar.Metrics.Neighbors + doctest Neighbors +end diff --git a/test/scholar/neighbors/knn_classifier_test.exs 
b/test/scholar/neighbors/knn_classifier_test.exs new file mode 100644 index 00000000..5d6731b2 --- /dev/null +++ b/test/scholar/neighbors/knn_classifier_test.exs @@ -0,0 +1,112 @@ +defmodule Scholar.Neighbors.KNNClassifierTest do + use Scholar.Case, async: true + alias Scholar.Neighbors.KNNClassifier + doctest KNNClassifier + + defp x_train do + Nx.tensor([ + [3, 6, 7, 5], + [9, 8, 5, 4], + [4, 4, 4, 1], + [9, 4, 5, 6], + [6, 4, 5, 7], + [4, 5, 3, 3], + [4, 5, 7, 8], + [9, 4, 4, 5], + [8, 4, 3, 9], + [2, 8, 4, 4] + ]) + end + + defp y_train do + Nx.tensor([0, 1, 1, 1, 1, 1, 1, 1, 0, 0]) + end + + defp x do + Nx.tensor([[4, 3, 8, 4], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) + end + + describe "fit" do + test "fit with default parameters - :num_classes set to 2" do + model = KNNClassifier.fit(x_train(), y_train(), num_neighbors: 3, num_classes: 2) + + assert model.algorithm == Scholar.Neighbors.BruteKNN.fit(x_train(), num_neighbors: 3) + assert model.num_classes == 2 + assert model.labels == y_train() + assert model.weights == :uniform + end + + test "fit with k-d tree" do + model = + KNNClassifier.fit(x_train(), y_train(), + algorithm: :kd_tree, + num_neighbors: 3, + num_classes: 2 + ) + + assert model.algorithm == Scholar.Neighbors.KDTree.fit(x_train(), num_neighbors: 3) + assert model.num_classes == 2 + assert model.labels == y_train() + assert model.weights == :uniform + end + end + + describe "predict" do + test "predict with default values" do + model = KNNClassifier.fit(x_train(), y_train(), num_neighbors: 3, num_classes: 2) + labels_pred = KNNClassifier.predict(model, x()) + assert labels_pred == Nx.tensor([1, 1, 0, 1]) + end + + test "predict with k-d tree" do + model = + KNNClassifier.fit(x_train(), y_train(), + algorithm: :kd_tree, + num_neighbors: 3, + num_classes: 2 + ) + + labels_pred = KNNClassifier.predict(model, x()) + assert labels_pred == Nx.tensor([1, 1, 0, 1]) + end + + test "predict with weights set to :distance" do + model = + KNNClassifier.fit(x_train(), y_train(), + num_neighbors: 3, + num_classes: 2, + weights: :distance + ) + + labels_pred = KNNClassifier.predict(model, x()) + assert labels_pred == Nx.tensor([1, 1, 0, 1]) + end + + test "predict with specific metric and weights set to :distance" do + model = + KNNClassifier.fit(x_train(), y_train(), + num_neighbors: 3, + num_classes: 2, + metric: {:minkowski, 1.5}, + weights: :distance + ) + + labels_pred = KNNClassifier.predict(model, x()) + assert labels_pred == Nx.tensor([1, 1, 0, 1]) + end + + test "predict with weights set to :distance and with x that contains sample with zero-distance" do + x = Nx.tensor([[3, 6, 7, 5], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) + + model = + KNNClassifier.fit(x_train(), y_train(), + num_neighbors: 3, + num_classes: 2, + weights: :distance + ) + + labels_pred = KNNClassifier.predict(model, x) + assert labels_pred == Nx.tensor([0, 1, 0, 1]) + end + end +end From 410bb99e8ec48056c7169df9913e50cbd4f254d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Sun, 12 May 2024 21:03:40 +0200 Subject: [PATCH 3/6] Update doc --- lib/scholar/neighbors/knn_classifier.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/scholar/neighbors/knn_classifier.ex b/lib/scholar/neighbors/knn_classifier.ex index f56b6e58..455b3b04 100644 --- a/lib/scholar/neighbors/knn_classifier.ex +++ b/lib/scholar/neighbors/knn_classifier.ex @@ -2,7 +2,7 @@ defmodule Scholar.Neighbors.KNNClassifier do @moduledoc """ K-Nearest Neighbors Classifier. - ... 
+  The model classifies the point by looking at its k-nearest neighbors and performing a (weighted) majority voting.
   """

   import Nx.Defn

From 6b9ddd05bcefb2870f0e5ed32efe949959594896 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?=
Date: Sun, 12 May 2024 21:12:09 +0200
Subject: [PATCH 4/6] Update doc

---
 lib/scholar/neighbors/knn_classifier.ex | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/scholar/neighbors/knn_classifier.ex b/lib/scholar/neighbors/knn_classifier.ex
index 455b3b04..d5bfb0a7 100644
--- a/lib/scholar/neighbors/knn_classifier.ex
+++ b/lib/scholar/neighbors/knn_classifier.ex
@@ -2,7 +2,7 @@ defmodule Scholar.Neighbors.KNNClassifier do
   @moduledoc """
   K-Nearest Neighbors Classifier.

-  The model classifies the point by looking at its k-nearest neighbors and performing a (weighted) majority voting.
+  Performs classification by looking at the k-nearest neighbors of a point and using (weighted) majority voting.
   """

   import Nx.Defn

From ea158f196623e8f5cdaaa3cd2af5b01633576e38 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?=
Date: Tue, 14 May 2024 15:13:10 +0200
Subject: [PATCH 5/6] Add metric to RandomProjectionForest and LargeVis, more
 unit-tests, etc

---
 lib/scholar/neighbors/brute_knn.ex            |   2 +-
 lib/scholar/neighbors/knn_classifier.ex       | 138 ++++++++++--------
 lib/scholar/neighbors/large_vis.ex            |  32 +++-
 .../neighbors/random_projection_forest.ex     |  24 ++-
 lib/scholar/neighbors/utils.ex                |  18 ++-
 .../scholar/neighbors/knn_classifier_test.exs |  24 +++
 6 files changed, 161 insertions(+), 77 deletions(-)

diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex
index bfb584af..97ecbaab 100644
--- a/lib/scholar/neighbors/brute_knn.ex
+++ b/lib/scholar/neighbors/brute_knn.ex
@@ -25,7 +25,7 @@ defmodule Scholar.Neighbors.BruteKNN do
       type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]},
       default: {:minkowski, 2},
       doc: ~S"""
-      The function that measures distance between two points. Possible values:
+      The function that measures the distance between two points. Possible values:

       * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
       we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.

diff --git a/lib/scholar/neighbors/knn_classifier.ex b/lib/scholar/neighbors/knn_classifier.ex
index d5bfb0a7..a6a963b8 100644
--- a/lib/scholar/neighbors/knn_classifier.ex
+++ b/lib/scholar/neighbors/knn_classifier.ex
@@ -2,23 +2,22 @@ defmodule Scholar.Neighbors.KNNClassifier do
   @moduledoc """
   K-Nearest Neighbors Classifier.

-  Performs classification by looking at the k-nearest neighbors of a point and using (weighted) majority voting.
+  Performs classification by computing the (weighted) majority voting among k-nearest neighbors.
   """

   import Nx.Defn
+  import Scholar.Shared
   require Nx

-  @derive {Nx.Container, keep: [:algorithm, :num_classes, :weights], containers: [:labels]}
+  @derive {Nx.Container, keep: [:num_classes, :weights], containers: [:algorithm, :labels]}
   defstruct [:algorithm, :num_classes, :weights, :labels]

   opts = [
     algorithm: [
-      type: {:or, [:atom, {:tuple, [:atom, :keyword_list]}]},
+      type: :atom,
       default: :brute,
       doc: """
-      k-NN algorithm to be used for finding the nearest neighbors. It can be provided as
-      an atom or a tuple containing an atom and algorithm specific options.
-      Possible values for the atom:
+      Algorithm used to compute the k-nearest neighbors. 
Possible values: * `:brute` - Brute-force search. See `Scholar.Neighbors.BruteKNN` for more details. @@ -26,26 +25,8 @@ defmodule Scholar.Neighbors.KNNClassifier do * `:random_projection_forest` - Random projection forest. See `Scholar.Neighbors.RandomProjectionForest` for more details. - * Module implementing fit/2 and predict/2. - """ - ], - num_neighbors: [ - required: true, - type: :pos_integer, - doc: "The number of nearest neighbors." - ], - metric: [ - type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]}, - default: {:minkowski, 2}, - doc: """ - The function that measures distance between two points. Possible values: - - * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`) - we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric. - - * `:cosine` - Cosine metric. - - Keep in mind that different algorithms support different metrics. For more information have a look at the corresponding modules. + * Module implementing `fit(data, opts)` and `predict(model, query)`. predict/2 must return tuple containing indices + of k-nearest neighbors of query points as well as distances between query points and their k-nearest neighbors. """ ], num_classes: [ @@ -76,6 +57,8 @@ defmodule Scholar.Neighbors.KNNClassifier do #{NimbleOptions.docs(@opts_schema)} + Algorithm-specific options (e.g. `:num_neighbors`, `:metric`) should be provided together with the classifier options. + ## Examples iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) @@ -85,31 +68,54 @@ defmodule Scholar.Neighbors.KNNClassifier do Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: 3) iex> model.labels Nx.tensor([0, 0, 0, 1, 1]) + + iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) + iex> y = Nx.tensor([0, 0, 0, 1, 1]) + iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, algorithm: :kd_tree, num_neighbors: 3, metric: {:minkowski, 1}, num_classes: 2) + iex> model.algorithm + Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1}) + iex> model.labels + Nx.tensor([0, 0, 0, 1, 1]) + + iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) + iex> y = Nx.tensor([0, 0, 0, 1, 1]) + iex> key = Nx.Random.key(12) + iex> model = Scholar.Neighbors.KNNClassifier.fit(x, y, algorithm: :random_projection_forest, num_neighbors: 2, num_classes: 2, num_trees: 4, key: key) + iex> model.algorithm + Scholar.Neighbors.RandomProjectionForest.fit(x, num_neighbors: 2, num_trees: 4, key: key) + iex> model.labels + Nx.tensor([0, 0, 0, 1, 1]) """ deftransform fit(x, y, opts) do if Nx.rank(x) != 2 do raise ArgumentError, - "expected x to have shape {num_samples, num_features}, - got tensor with shape: #{inspect(Nx.shape(x))}" + """ + expected x to have shape {num_samples, num_features}, \ + got tensor with shape: #{inspect(Nx.shape(x))} + """ end - if Nx.rank(y) != 1 and Nx.axis_size(x, 0) == Nx.axis_size(y, 0) do + if Nx.rank(y) != 1 do raise ArgumentError, - "expected y to have shape {num_samples}, - got tensor with shape: #{inspect(Nx.shape(y))}" + """ + expected y to have shape {num_samples}, \ + got tensor with shape: #{inspect(Nx.shape(y))} + """ end - opts = NimbleOptions.validate!(opts, @opts_schema) + if Nx.axis_size(x, 0) != Nx.axis_size(y, 0) do + raise ArgumentError, + """ + expected x and y to have the same first dimension, \ + got #{Nx.axis_size(x, 0)} and #{Nx.axis_size(y, 0)} + """ + end - {algorithm_name, algorithm_opts} = - if is_atom(opts[:algorithm]) do - {opts[:algorithm], []} - 
else
-        opts[:algorithm]
-      end

-    knn_module =
-      case algorithm_name do
+    algorithm_module =
+      case opts[:algorithm] do
         :brute ->
           Scholar.Neighbors.BruteKNN

         :kd_tree ->
           Scholar.Neighbors.KDTree

         :random_projection_forest ->
           Scholar.Neighbors.RandomProjectionForest

-        knn_module when is_atom(knn_module) ->
-          knn_module
-
-        _ ->
-          raise ArgumentError,
-                """
-                not supported
-                """
+        module when is_atom(module) ->
+          module
       end

-    # TODO: Maybe raise an error if :num_neighbors or :metric is already in algorithm_opts?
-
-    algorithm_opts = Keyword.put(algorithm_opts, :num_neighbors, opts[:num_neighbors])
-    algorithm_opts = Keyword.put(algorithm_opts, :metric, opts[:metric])
-
-    algorithm = knn_module.fit(x, algorithm_opts)
+    algorithm = algorithm_module.fit(x, algorithm_opts)

     %__MODULE__{
       algorithm: algorithm,
@@ -156,9 +151,8 @@ defmodule Scholar.Neighbors.KNNClassifier do
       iex> Scholar.Neighbors.KNNClassifier.predict(model, x_test)
       Nx.tensor([0, 0, 1])
   """
-  deftransform predict(model, x) do
-    knn_module = model.algorithm.__struct__
-    {neighbors, distances} = knn_module.predict(model.algorithm, x)
+  defn predict(model, x) do
+    {neighbors, distances} = compute_knn(model.algorithm, x)
     labels_pred = Nx.take(model.labels, neighbors)

     case model.weights do
@@ -167,6 +161,36 @@ defmodule Scholar.Neighbors.KNNClassifier do
     end
   end

+  defn predict_proba(model, x) do
+    num_samples = Nx.axis_size(x, 0)
+    {neighbors, distances} = compute_knn(model.algorithm, x)
+    labels_pred = Nx.take(model.labels, neighbors)
+    type = Nx.Type.merge(to_float_type(x), {:f, 32})
+    proba = Nx.broadcast(Nx.tensor(0.0, type: type), {num_samples, model.num_classes})
+
+    weights =
+      case model.weights do
+        :distance -> check_weights(distances)
+        :uniform -> Nx.broadcast(1.0, neighbors)
+      end
+
+    indices =
+      Nx.stack(
+        [Nx.iota(Nx.shape(labels_pred), axis: 0), labels_pred],
+        axis: -1 # TODO: Replace -1 here
+      )
+      |> Nx.flatten(axes: [0, 1])
+
+    proba = Nx.indexed_add(proba, indices, Nx.flatten(weights))
+    normalizer = Nx.sum(proba, axes: [1])
+    normalizer = Nx.select(normalizer == 0, 1, normalizer)
+    proba / Nx.new_axis(normalizer, -1) # TODO: Replace -1 here
+  end
+
+  deftransformp compute_knn(algorithm, x) do
+    algorithm.__struct__.predict(algorithm, x)
+  end
+
   defnp check_weights(weights) do
     zero_mask = weights == 0
     zero_rows = zero_mask |> Nx.any(axes: [1], keep_axes: true) |> Nx.broadcast(weights)
diff --git a/lib/scholar/neighbors/large_vis.ex b/lib/scholar/neighbors/large_vis.ex
index 8f98fdbf..076a9c9c 100644
--- a/lib/scholar/neighbors/large_vis.ex
+++ b/lib/scholar/neighbors/large_vis.ex
@@ -14,7 +14,7 @@ defmodule Scholar.Neighbors.LargeVis do
   import Nx.Defn
   import Scholar.Shared
   require Nx
-  alias Scholar.Neighbors.RandomProjectionForest, as: Forest
+  alias Scholar.Neighbors.RandomProjectionForest
   alias Scholar.Neighbors.Utils

   opts = [
@@ -23,6 +23,11 @@ defmodule Scholar.Neighbors.LargeVis do
       type: :pos_integer,
       doc: "The number of neighbors in the graph."
     ],
+    metric: [
+      type: {:in, [:squared_euclidean, :euclidean]},
+      default: :euclidean,
+      doc: "The function that measures distance between two points." 
+ ], min_leaf_size: [ type: :pos_integer, doc: """ @@ -63,7 +68,7 @@ defmodule Scholar.Neighbors.LargeVis do iex> key = Nx.Random.key(12) iex> tensor = Nx.iota({5, 2}) - iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, min_leaf_size: 2, num_trees: 3, key: key) + iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, metric: :squared_euclidean, min_leaf_size: 2, num_trees: 3, key: key) iex> graph #Nx.Tensor< u32[5][2] @@ -98,6 +103,13 @@ defmodule Scholar.Neighbors.LargeVis do opts = NimbleOptions.validate!(opts, @opts_schema) k = opts[:num_neighbors] + + metric = + case opts[:metric] do + :euclidean -> &Scholar.Metrics.Distance.euclidean/2 + :squared_euclidean -> &Scholar.Metrics.Distance.squared_euclidean/2 + end + min_leaf_size = opts[:min_leaf_size] || max(10, 2 * k) size = Nx.axis_size(tensor, 0) @@ -108,6 +120,7 @@ defmodule Scholar.Neighbors.LargeVis do tensor, key, num_neighbors: k, + metric: metric, min_leaf_size: min_leaf_size, num_trees: num_trees, num_iters: opts[:num_iters] @@ -116,15 +129,15 @@ defmodule Scholar.Neighbors.LargeVis do defnp fit_n(tensor, key, opts) do forest = - Forest.fit(tensor, + RandomProjectionForest.fit(tensor, num_neighbors: opts[:num_neighbors], min_leaf_size: opts[:min_leaf_size], num_trees: opts[:num_trees], key: key ) - {graph, _} = Forest.predict(forest, tensor) - expand(graph, tensor, num_iters: opts[:num_iters]) + {graph, _} = RandomProjectionForest.predict(forest, tensor) + expand(graph, tensor, metric: opts[:metric], num_iters: opts[:num_iters]) end defn expand(graph, tensor, opts) do @@ -140,17 +153,20 @@ defmodule Scholar.Neighbors.LargeVis do {tensor, iter = 0} }, iter < num_iters do - {expansion_iter(graph, tensor), {tensor, iter + 1}} + {expansion_iter(graph, tensor, metric: opts[:metric]), {tensor, iter + 1}} end result end - defnp expansion_iter(graph, tensor) do + defnp expansion_iter(graph, tensor, opts) do {size, k} = Nx.shape(graph) candidate_indices = Nx.take(graph, graph) |> Nx.reshape({size, k * k}) candidate_indices = Nx.concatenate([graph, candidate_indices], axis: 1) - Utils.find_neighbors(tensor, tensor, candidate_indices, num_neighbors: k) + Utils.brute_force_search_with_candidates(tensor, tensor, candidate_indices, + num_neighbors: k, + metric: opts[:metric] + ) end end diff --git a/lib/scholar/neighbors/random_projection_forest.ex b/lib/scholar/neighbors/random_projection_forest.ex index 5c89d173..b610299e 100644 --- a/lib/scholar/neighbors/random_projection_forest.ex +++ b/lib/scholar/neighbors/random_projection_forest.ex @@ -23,7 +23,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do alias Scholar.Neighbors.Utils @derive {Nx.Container, - keep: [:num_neighbors, :depth, :leaf_size, :num_trees], + keep: [:num_neighbors, :metric, :depth, :leaf_size, :num_trees], containers: [:indices, :data, :hyperplanes, :medians]} @enforce_keys [ :num_neighbors, @@ -37,6 +37,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do ] defstruct [ :num_neighbors, + :metric, :depth, :leaf_size, :num_trees, @@ -52,6 +53,11 @@ defmodule Scholar.Neighbors.RandomProjectionForest do type: :pos_integer, doc: "The number of nearest neighbors." ], + metric: [ + type: {:in, [:squared_euclidean, :euclidean]}, + default: :euclidean, + doc: "The function that measures the distance between two points." + ], min_leaf_size: [ type: :pos_integer, doc: "The minumum number of points in the leaf." 
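
Both metric atoms resolve to ordinary two-arity distance functions from `Scholar.Metrics.Distance`, as the `fit/2` hunk below shows. A minimal sketch of what the stored closure computes (the input values here are illustrative):

    metric = &Scholar.Metrics.Distance.squared_euclidean/2
    metric.(Nx.tensor([0.0, 0.0]), Nx.tensor([3.0, 4.0]))
    #=> #Nx.Tensor<
    #     f32
    #     25.0
    #   >
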
@@ -107,6 +113,12 @@ defmodule Scholar.Neighbors.RandomProjectionForest do num_neighbors = opts[:num_neighbors] min_leaf_size = opts[:min_leaf_size] + metric = + case opts[:metric] do + :euclidean -> &Scholar.Metrics.Distance.euclidean/2 + :squared_euclidean -> &Scholar.Metrics.Distance.squared_euclidean/2 + end + min_leaf_size = cond do is_nil(min_leaf_size) -> @@ -142,6 +154,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do %__MODULE__{ num_neighbors: num_neighbors, + metric: metric, depth: depth, leaf_size: leaf_size, num_trees: num_trees, @@ -283,7 +296,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do iex> key = Nx.Random.key(12) iex> tensor = Nx.iota({5, 2}) - iex> forest = Scholar.Neighbors.RandomProjectionForest.fit(tensor, num_neighbors: 2, num_trees: 3, key: key) + iex> forest = Scholar.Neighbors.RandomProjectionForest.fit(tensor, num_neighbors: 2, metric: :squared_euclidean, num_trees: 3, key: key) iex> query = Nx.tensor([[3, 4]]) iex> {neighbors, distances} = Scholar.Neighbors.RandomProjectionForest.predict(forest, query) iex> neighbors @@ -323,9 +336,12 @@ defmodule Scholar.Neighbors.RandomProjectionForest do end defnp predict_n(forest, query) do - k = forest.num_neighbors candidate_indices = get_leaves(forest, query) - Utils.find_neighbors(query, forest.data, candidate_indices, num_neighbors: k) + + Utils.brute_force_search_with_candidates(forest.data, query, candidate_indices, + num_neighbors: forest.num_neighbors, + metric: forest.metric + ) end @doc false diff --git a/lib/scholar/neighbors/utils.ex b/lib/scholar/neighbors/utils.ex index f202348b..06acc8c4 100644 --- a/lib/scholar/neighbors/utils.ex +++ b/lib/scholar/neighbors/utils.ex @@ -3,16 +3,20 @@ defmodule Scholar.Neighbors.Utils do import Nx.Defn require Nx - defn find_neighbors(query, data, candidate_indices, opts) do + defn brute_force_search_with_candidates(data, query, candidate_indices, opts) do k = opts[:num_neighbors] + metric = opts[:metric] + dim = Nx.axis_size(data, 1) {size, length} = Nx.shape(candidate_indices) - distances = + x = query |> Nx.new_axis(1) - |> Nx.subtract(Nx.take(data, candidate_indices)) - |> Nx.pow(2) - |> Nx.sum(axes: [2]) + |> Nx.broadcast({size, length, dim}) + |> Nx.vectorize([:query, :candidates]) + + y = Nx.take(data, candidate_indices) |> Nx.vectorize([:query, :candidates]) + distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil) distances = if length > 1 do @@ -55,11 +59,11 @@ defmodule Scholar.Neighbors.Utils do target = Nx.broadcast(Nx.u32(0), {size, length}) samples = Nx.iota({size, length, 1}, axis: 0) - indices = + target_indices = Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2) |> Nx.reshape({size * length, 2}) updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length}) - Nx.indexed_add(target, indices, updates) + Nx.indexed_add(target, target_indices, updates) end end diff --git a/test/scholar/neighbors/knn_classifier_test.exs b/test/scholar/neighbors/knn_classifier_test.exs index 5d6731b2..19819f4f 100644 --- a/test/scholar/neighbors/knn_classifier_test.exs +++ b/test/scholar/neighbors/knn_classifier_test.exs @@ -49,6 +49,30 @@ defmodule Scholar.Neighbors.KNNClassifierTest do assert model.labels == y_train() assert model.weights == :uniform end + + test "fit with random projection forest" do + key = Nx.Random.key(12) + + model = + KNNClassifier.fit(x_train(), y_train(), + algorithm: :random_projection_forest, + num_neighbors: 3, + num_classes: 2, + num_trees: 4, + key: key + ) + + assert model.algorithm == + 
Scholar.Neighbors.RandomProjectionForest.fit(x_train(), + num_neighbors: 3, + num_trees: 4, + key: key + ) + + assert model.num_classes == 2 + assert model.labels == y_train() + assert model.weights == :uniform + end end describe "predict" do From 19f1f754d8def506972f20bd9abc6f3b3045f5e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krsto=20Prorokovi=C4=87?= Date: Tue, 14 May 2024 18:10:14 +0200 Subject: [PATCH 6/6] Add predict_proba/2 --- lib/scholar/neighbors/kd_tree.ex | 15 ++- lib/scholar/neighbors/knn_classifier.ex | 30 ++++-- .../neighbors/random_projection_forest.ex | 6 +- .../scholar/neighbors/knn_classifier_test.exs | 99 +++++++++++++++++-- 4 files changed, 130 insertions(+), 20 deletions(-) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index 1228a935..3c3bbead 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -318,11 +318,20 @@ defmodule Scholar.Neighbors.KDTree do """ deftransform predict(tree, data) do if Nx.rank(data) != 2 do - raise ArgumentError, "Input data must be a 2D tensor" + raise ArgumentError, + """ + expected query tensor to have shape {num_queries, num_features}, \ + got tensor with shape: #{inspect(Nx.shape(data))} + """ end - if Nx.axis_size(data, -1) != Nx.axis_size(tree.data, -1) do - raise ArgumentError, "Input data must have the same number of features as the training data" + if Nx.axis_size(tree.data, 1) != Nx.axis_size(data, 1) do + raise ArgumentError, + """ + expected query tensor to have same number of features as tensor used to fit the tree, \ + got #{inspect(Nx.axis_size(data, 1))} \ + and #{inspect(Nx.axis_size(tree.data, 1))} + """ end predict_n(tree, data) diff --git a/lib/scholar/neighbors/knn_classifier.ex b/lib/scholar/neighbors/knn_classifier.ex index a6a963b8..4ddea1ba 100644 --- a/lib/scholar/neighbors/knn_classifier.ex +++ b/lib/scholar/neighbors/knn_classifier.ex @@ -140,15 +140,15 @@ defmodule Scholar.Neighbors.KNNClassifier do end @doc """ - Makes predictions using a k-NN classifier model. + Predicts classes using a k-NN classifier model. ## Examples iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]]) iex> y_train = Nx.tensor([0, 0, 0, 1, 1]) iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2) - iex> x_test = Nx.tensor([[1, 3], [4, 2], [3, 6]]) - iex> Scholar.Neighbors.KNNClassifier.predict(model, x_test) + iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]]) + iex> Scholar.Neighbors.KNNClassifier.predict(model, x) Nx.tensor([0, 0, 1]) """ defn predict(model, x) do @@ -161,6 +161,24 @@ defmodule Scholar.Neighbors.KNNClassifier do end end + @doc """ + Predicts class probabilities using a k-NN classifier model. 
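+
+  Each row of the result holds the (weighted) fraction of nearest-neighbor
+  votes that each class received, so the output has shape {num_samples, num_classes}.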
+
+  ## Examples
+
+      iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
+      iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
+      iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
+      iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
+      iex> Scholar.Neighbors.KNNClassifier.predict_proba(model, x)
+      Nx.tensor(
+        [
+          [1.0, 0.0],
+          [1.0, 0.0],
+          [0.3333333432674408, 0.6666666865348816]
+        ]
+      )
+  """
   defn predict_proba(model, x) do
     num_samples = Nx.axis_size(x, 0)
     {neighbors, distances} = compute_knn(model.algorithm, x)
     labels_pred = Nx.take(model.labels, neighbors)
@@ -170,21 +188,21 @@ defmodule Scholar.Neighbors.KNNClassifier do

     weights =
       case model.weights do
-        :distance -> check_weights(distances)
-        :uniform -> Nx.broadcast(1.0, neighbors)
+        :uniform -> Nx.broadcast(1.0, neighbors)
+        :distance -> check_weights(distances)
       end

     indices =
       Nx.stack(
         [Nx.iota(Nx.shape(labels_pred), axis: 0), labels_pred],
-        axis: -1 # TODO: Replace -1 here
+        axis: 2
       )
       |> Nx.flatten(axes: [0, 1])

     proba = Nx.indexed_add(proba, indices, Nx.flatten(weights))
     normalizer = Nx.sum(proba, axes: [1])
     normalizer = Nx.select(normalizer == 0, 1, normalizer)
-    proba / Nx.new_axis(normalizer, -1) # TODO: Replace -1 here
+    proba / Nx.new_axis(normalizer, 1)
   end

   deftransformp compute_knn(algorithm, x) do
diff --git a/lib/scholar/neighbors/random_projection_forest.ex b/lib/scholar/neighbors/random_projection_forest.ex
index b610299e..6c235b33 100644
--- a/lib/scholar/neighbors/random_projection_forest.ex
+++ b/lib/scholar/neighbors/random_projection_forest.ex
@@ -326,9 +326,9 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
     if Nx.axis_size(forest.data, 1) != Nx.axis_size(query, 1) do
       raise ArgumentError,
             """
-            expected query tensor to have the same dimension as tensor used to grow the forest, \
-            got #{inspect(Nx.axis_size(forest.data, 1))} \
-            and #{inspect(Nx.axis_size(query, 1))}
+            expected query tensor to have same number of features as tensor used to grow the forest, \
+            got #{inspect(Nx.axis_size(query, 1))} \
+            and #{inspect(Nx.axis_size(forest.data, 1))}
             """
     end

diff --git a/test/scholar/neighbors/knn_classifier_test.exs b/test/scholar/neighbors/knn_classifier_test.exs
index 19819f4f..2c08311d 100644
--- a/test/scholar/neighbors/knn_classifier_test.exs
+++ b/test/scholar/neighbors/knn_classifier_test.exs
@@ -40,8 +40,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
       model =
         KNNClassifier.fit(x_train(), y_train(),
           algorithm: :kd_tree,
-          num_neighbors: 3,
-          num_classes: 2
+          num_classes: 2,
+          num_neighbors: 3
         )

       assert model.algorithm == Scholar.Neighbors.KDTree.fit(x_train(), num_neighbors: 3)
@@ -56,8 +56,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
       model =
         KNNClassifier.fit(x_train(), y_train(),
           algorithm: :random_projection_forest,
-          num_neighbors: 3,
           num_classes: 2,
+          num_neighbors: 3,
           num_trees: 4,
           key: key
         )
@@ -86,8 +86,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
       model =
         KNNClassifier.fit(x_train(), y_train(),
           algorithm: :kd_tree,
-          num_neighbors: 3,
-          num_classes: 2
+          num_classes: 2,
+          num_neighbors: 3
         )

       labels_pred = KNNClassifier.predict(model, x())
@@ -97,8 +97,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
     test "predict with weights set to :distance" do
       model =
         KNNClassifier.fit(x_train(), y_train(),
-          num_neighbors: 3,
           num_classes: 2,
+          num_neighbors: 3,
           weights: :distance
         )
@@ -109,8 +109,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
     test "predict with specific metric and weights set to :distance" do
       model =
         KNNClassifier.fit(x_train(), y_train(),
-          num_neighbors: 3,
           num_classes: 2,
+          num_neighbors: 3,
           metric: 
{:minkowski, 1.5}, weights: :distance ) @@ -124,8 +124,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do model = KNNClassifier.fit(x_train(), y_train(), - num_neighbors: 3, num_classes: 2, + num_neighbors: 3, weights: :distance ) @@ -133,4 +133,87 @@ defmodule Scholar.Neighbors.KNNClassifierTest do assert labels_pred == Nx.tensor([0, 1, 0, 1]) end end + + describe "predict_proba" do + test "predict_proba with default values" do + model = KNNClassifier.fit(x_train(), y_train(), num_classes: 2, num_neighbors: 3) + predictions = KNNClassifier.predict_proba(model, x()) + + assert_all_close( + predictions, + Nx.tensor([ + [0.33333333, 0.66666667], + [0.33333333, 0.66666667], + [0.66666667, 0.33333333], + [0.0, 1.0] + ]) + ) + end + + test "predict_proba with weights set to :distance" do + model = + KNNClassifier.fit(x_train(), y_train(), + num_neighbors: 3, + num_classes: 2, + weights: :distance + ) + + predictions = KNNClassifier.predict_proba(model, x()) + + assert_all_close( + predictions, + Nx.tensor([ + [0.40351151, 0.59648849], + [0.31717204, 0.68282796], + [0.7283494, 0.2716506], + [0.0, 1.0] + ]) + ) + end + + test "predict_proba with weights set to :distance and with specific metric" do + model = + KNNClassifier.fit(x_train(), y_train(), + num_classes: 2, + num_neighbors: 3, + weights: :distance, + metric: {:minkowski, 1.5} + ) + + predictions = KNNClassifier.predict_proba(model, x()) + + assert_all_close( + predictions, + Nx.tensor([ + [0.40381038, 0.59618962], + [0.31457406, 0.68542594], + [0.72993802, 0.27006198], + [0.0, 1.0] + ]) + ) + end + + test "predict_proba with weights set to :distance and with x that contains sample with zero-distance" do + x = Nx.tensor([[3, 6, 7, 5], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) + + model = + KNNClassifier.fit(x_train(), y_train(), + num_classes: 2, + num_neighbors: 3, + weights: :distance + ) + + predictions = KNNClassifier.predict_proba(model, x) + + assert_all_close( + predictions, + Nx.tensor([ + [1.0, 0.0], + [0.31717204, 0.68282796], + [0.7283494, 0.2716506], + [0.0, 1.0] + ]) + ) + end + end end
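
For reference, the end-to-end API introduced by this series, condensed from the doctests above into a single usage sketch:

    x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
    y_train = Nx.tensor([0, 0, 0, 1, 1])

    # Defaults to the brute-force backend; passing algorithm: :kd_tree or
    # :random_projection_forest (plus that module's options) works the same way.
    model =
      Scholar.Neighbors.KNNClassifier.fit(x_train, y_train,
        num_neighbors: 3,
        num_classes: 2
      )

    x = Nx.tensor([[1, 3], [4, 2], [3, 6]])

    # Hard labels, shape {num_queries}.
    Scholar.Neighbors.KNNClassifier.predict(model, x)
    #=> Nx.tensor([0, 0, 1])

    # Per-class vote fractions, shape {num_queries, num_classes}.
    Scholar.Neighbors.KNNClassifier.predict_proba(model, x)
    #=> Nx.tensor([[1.0, 0.0], [1.0, 0.0], [0.3333333432674408, 0.6666666865348816]])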