From 433041f3b26d7d1b3b27f413b196ae083c94d668 Mon Sep 17 00:00:00 2001 From: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> Date: Fri, 14 Jun 2024 11:51:02 +0200 Subject: [PATCH] Make nn algorithm configurable (#281) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Make nn algorithm configurable * Update lib/scholar/manifold/trimap.ex Co-authored-by: José Valim * Make tests passing * Add :auto option for knn_algo auto detection --------- Co-authored-by: José Valim --- lib/scholar/manifold/trimap.ex | 82 +++++++++++++++++++++++---- lib/scholar/neighbors/brute_knn.ex | 6 ++ lib/scholar/neighbors/utils.ex | 8 +++ test/scholar/manifold/trimap_test.exs | 26 ++++++--- 4 files changed, 102 insertions(+), 20 deletions(-) diff --git a/lib/scholar/manifold/trimap.ex b/lib/scholar/manifold/trimap.ex index 755c8cde..abbaa425 100644 --- a/lib/scholar/manifold/trimap.ex +++ b/lib/scholar/manifold/trimap.ex @@ -113,6 +113,20 @@ defmodule Scholar.Manifold.Trimap do doc: ~S""" Metric used to compute the distances. """ + ], + knn_algorithm: [ + type: {:in, [:auto, :nndescent, :large_vis, :brute]}, + default: :auto, + doc: ~S""" + Algorithm used to compute the nearest neighbors. Possible values: + * `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details. + + * `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LargeVis` for more details. + + * `:brute` - Brute force algorithm. See `Scholar.Neighbors.BruteKNN` for more details. + + * `:auto` - Automatically selects the algorithm based on the number of points. + """ ] ] @@ -290,15 +304,61 @@ defmodule Scholar.Manifold.Trimap do num_points = Nx.axis_size(inputs, 0) num_extra = min(num_inliners + 50, num_points) - nndescent = - Scholar.Neighbors.NNDescent.fit(inputs, - num_neighbors: num_extra, - tree_init?: false, - metric: opts[:metric], - tol: 1.0e-5 - ) - - neighbors = nndescent.nearest_neighbors + neighbors = + case opts[:knn_algorithm] do + :brute -> + model = + Scholar.Neighbors.BruteKNN.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) + neighbors + + :nndescent -> + nndescent = + Scholar.Neighbors.NNDescent.fit(inputs, + num_neighbors: num_extra, + tree_init?: false, + metric: opts[:metric], + tol: 1.0e-5, + key: key + ) + + nndescent.nearest_neighbors + + :large_vis -> + {neighbors, _distances} = + Scholar.Neighbors.LargeVis.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric], + key: key + ) + + neighbors + + :auto -> + if Nx.axis_size(inputs, 0) <= 500 do + model = + Scholar.Neighbors.BruteKNN.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) + neighbors + else + {neighbors, _distances} = + Scholar.Neighbors.LargeVis.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric], + key: key + ) + + neighbors + end + end neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1) @@ -402,9 +462,9 @@ defmodule Scholar.Manifold.Trimap do ## Examples iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5}) - iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key) + iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, knn_algorithm: :nndescent) """ - deftransform embed(inputs, opts \\ []) do + deftransform transform(inputs, opts \\ []) do opts = NimbleOptions.validate!(opts, @opts_schema) key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) {triplets, opts} = Keyword.pop(opts, :triplets, {}) diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 27afa7d6..131b1d6d 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -32,6 +32,12 @@ defmodule Scholar.Neighbors.BruteKNN do * `:cosine` - Cosine metric. + * `:euclidean` - Euclidean metric. + + * `:squared_euclidean` - Squared Euclidean metric. + + * `:manhattan` - Manhattan metric. + * Anonymous function of arity 2 that takes two rank-2 tensors. """ ], diff --git a/lib/scholar/neighbors/utils.ex b/lib/scholar/neighbors/utils.ex index 9da9ed5f..d111e301 100644 --- a/lib/scholar/neighbors/utils.ex +++ b/lib/scholar/neighbors/utils.ex @@ -22,6 +22,14 @@ defmodule Scholar.Neighbors.Utils do {:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: p)} end + def pairwise_metric(:euclidean), do: {:ok, &Scholar.Metrics.Distance.pairwise_euclidean/2} + + def pairwise_metric(:squared_euclidean), + do: {:ok, &Scholar.Metrics.Distance.pairwise_squared_euclidean/2} + + def pairwise_metric(:manhattan), + do: {:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: 1)} + def pairwise_metric(metric) when is_function(metric, 2), do: {:ok, metric} def pairwise_metric(metric) do diff --git a/test/scholar/manifold/trimap_test.exs b/test/scholar/manifold/trimap_test.exs index 75b0e71e..cffe1b92 100644 --- a/test/scholar/manifold/trimap_test.exs +++ b/test/scholar/manifold/trimap_test.exs @@ -7,7 +7,14 @@ defmodule Scholar.Manifold.TrimapTest do test "non default num_inliers and num_outliers" do x = Nx.iota({5, 6}) key = Nx.Random.key(42) - res = Trimap.embed(x, num_components: 2, key: key, num_inliers: 3, num_outliers: 1) + + res = + Trimap.transform(x, + num_components: 2, + key: key, + num_inliers: 3, + num_outliers: 1 + ) expected = Nx.tensor([ @@ -26,7 +33,7 @@ defmodule Scholar.Manifold.TrimapTest do key = Nx.Random.key(42) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -53,7 +60,7 @@ defmodule Scholar.Manifold.TrimapTest do key = Nx.Random.key(42) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -81,7 +88,7 @@ defmodule Scholar.Manifold.TrimapTest do weights = Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -116,7 +123,7 @@ defmodule Scholar.Manifold.TrimapTest do ]) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -141,12 +148,13 @@ defmodule Scholar.Manifold.TrimapTest do key = Nx.Random.key(42) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, num_outliers: 1, - metric: :manhattan + metric: :manhattan, + knn_algorithm: :nndescent ) expected = @@ -170,7 +178,7 @@ defmodule Scholar.Manifold.TrimapTest do assert_raise ArgumentError, "Number of points must be greater than 2", fn -> - Scholar.Manifold.Trimap.embed(x, + Scholar.Manifold.Trimap.transform(x, num_components: 2, key: key, num_inliers: 10, @@ -189,7 +197,7 @@ defmodule Scholar.Manifold.TrimapTest do "Triplets and weights must be either not initialized or have the same size of axis zero and rank of triplets must be 2 and rank of weights must be 1", fn -> - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3,