From 9d2e19986d40be213feffc9a570e271d3b5306db Mon Sep 17 00:00:00 2001 From: Mateusz Date: Mon, 10 Jun 2024 23:04:35 +0200 Subject: [PATCH 1/4] Make nn algorithm configurable --- lib/scholar/manifold/trimap.ex | 43 ++++++++++++++++++++------- test/scholar/manifold/trimap_test.exs | 28 ++++++++++++----- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/lib/scholar/manifold/trimap.ex b/lib/scholar/manifold/trimap.ex index 755c8cde..1bf01a59 100644 --- a/lib/scholar/manifold/trimap.ex +++ b/lib/scholar/manifold/trimap.ex @@ -113,6 +113,16 @@ defmodule Scholar.Manifold.Trimap do doc: ~S""" Metric used to compute the distances. """ + ], + algorithm: [ + type: {:in, [:nndescent, :large_vis]}, + default: :large_vis, + doc: ~S""" + Algorithm used to compute the nearest neighbors. Possible values: + * `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details. + + * `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LaregVis` for more details. + """ ] ] @@ -290,15 +300,28 @@ defmodule Scholar.Manifold.Trimap do num_points = Nx.axis_size(inputs, 0) num_extra = min(num_inliners + 50, num_points) - nndescent = - Scholar.Neighbors.NNDescent.fit(inputs, - num_neighbors: num_extra, - tree_init?: false, - metric: opts[:metric], - tol: 1.0e-5 - ) - - neighbors = nndescent.nearest_neighbors + neighbors = + case opts[:algorithm] do + :nndescent -> + nndescent = + Scholar.Neighbors.NNDescent.fit(inputs, + num_neighbors: num_extra, + tree_init?: false, + metric: opts[:metric], + tol: 1.0e-5 + ) + + nndescent.nearest_neighbors + + :large_vis -> + {neighbors, _distances} = + Scholar.Neighbors.LargeVis.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + neighbors + end neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1) @@ -402,7 +425,7 @@ defmodule Scholar.Manifold.Trimap do ## Examples iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5}) - iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key) + iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent) """ deftransform embed(inputs, opts \\ []) do opts = NimbleOptions.validate!(opts, @opts_schema) diff --git a/test/scholar/manifold/trimap_test.exs b/test/scholar/manifold/trimap_test.exs index 75b0e71e..971b384d 100644 --- a/test/scholar/manifold/trimap_test.exs +++ b/test/scholar/manifold/trimap_test.exs @@ -7,7 +7,15 @@ defmodule Scholar.Manifold.TrimapTest do test "non default num_inliers and num_outliers" do x = Nx.iota({5, 6}) key = Nx.Random.key(42) - res = Trimap.embed(x, num_components: 2, key: key, num_inliers: 3, num_outliers: 1) + + res = + Trimap.embed(x, + num_components: 2, + key: key, + num_inliers: 3, + num_outliers: 1, + algorithm: :nndescent + ) expected = Nx.tensor([ @@ -33,7 +41,8 @@ defmodule Scholar.Manifold.TrimapTest do num_outliers: 1, num_random: 5, weight_temp: 0.1, - learning_rate: 0.3 + learning_rate: 0.3, + algorithm: :nndescent ) expected = @@ -59,7 +68,8 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, num_iters: 100, - init_embedding_type: 1 + init_embedding_type: 1, + algorithm: :nndescent ) expected = @@ -87,7 +97,8 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights + weights: weights, + algorithm: :nndescent ) expected = @@ -121,7 +132,8 @@ defmodule Scholar.Manifold.TrimapTest do key: key, num_inliers: 3, num_outliers: 1, - init_embeddings: init_embeddings + init_embeddings: init_embeddings, + algorithm: :nndescent ) expected = @@ -146,7 +158,8 @@ defmodule Scholar.Manifold.TrimapTest do key: key, num_inliers: 3, num_outliers: 1, - metric: :manhattan + metric: :manhattan, + algorithm: :nndescent ) expected = @@ -195,7 +208,8 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights + weights: weights, + algorithm: :nndescent ) end end From b2ff6ddb760fda6cbaaadf7b7e10a5052cbc3c55 Mon Sep 17 00:00:00 2001 From: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> Date: Tue, 11 Jun 2024 13:14:25 +0200 Subject: [PATCH 2/4] Update lib/scholar/manifold/trimap.ex MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Valim --- lib/scholar/manifold/trimap.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/scholar/manifold/trimap.ex b/lib/scholar/manifold/trimap.ex index 1bf01a59..87f93756 100644 --- a/lib/scholar/manifold/trimap.ex +++ b/lib/scholar/manifold/trimap.ex @@ -121,7 +121,7 @@ defmodule Scholar.Manifold.Trimap do Algorithm used to compute the nearest neighbors. Possible values: * `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details. - * `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LaregVis` for more details. + * `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LargeVis` for more details. """ ] ] From f3640a678ce5d727f8f62f1bf38f69142b758522 Mon Sep 17 00:00:00 2001 From: Mateusz Date: Wed, 12 Jun 2024 20:18:34 +0200 Subject: [PATCH 3/4] Make tests passing --- lib/scholar/manifold/trimap.ex | 55 ++++++++++++++++----------- lib/scholar/neighbors/brute_knn.ex | 6 +++ lib/scholar/neighbors/utils.ex | 8 ++++ test/scholar/manifold/trimap_test.exs | 16 ++++---- 4 files changed, 55 insertions(+), 30 deletions(-) diff --git a/lib/scholar/manifold/trimap.ex b/lib/scholar/manifold/trimap.ex index 1bf01a59..3096a92d 100644 --- a/lib/scholar/manifold/trimap.ex +++ b/lib/scholar/manifold/trimap.ex @@ -301,26 +301,37 @@ defmodule Scholar.Manifold.Trimap do num_extra = min(num_inliners + 50, num_points) neighbors = - case opts[:algorithm] do - :nndescent -> - nndescent = - Scholar.Neighbors.NNDescent.fit(inputs, - num_neighbors: num_extra, - tree_init?: false, - metric: opts[:metric], - tol: 1.0e-5 - ) - - nndescent.nearest_neighbors - - :large_vis -> - {neighbors, _distances} = - Scholar.Neighbors.LargeVis.fit(inputs, - num_neighbors: num_extra, - metric: opts[:metric] - ) - - neighbors + if Nx.axis_size(inputs, 0) <= 500 do + model = + Scholar.Neighbors.BruteKNN.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) + neighbors + else + case opts[:algorithm] do + :nndescent -> + nndescent = + Scholar.Neighbors.NNDescent.fit(inputs, + num_neighbors: num_extra, + tree_init?: false, + metric: opts[:metric], + tol: 1.0e-5 + ) + + nndescent.nearest_neighbors + + :large_vis -> + {neighbors, _distances} = + Scholar.Neighbors.LargeVis.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + neighbors + end end neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1) @@ -425,9 +436,9 @@ defmodule Scholar.Manifold.Trimap do ## Examples iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5}) - iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent) + iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent) """ - deftransform embed(inputs, opts \\ []) do + deftransform transform(inputs, opts \\ []) do opts = NimbleOptions.validate!(opts, @opts_schema) key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) {triplets, opts} = Keyword.pop(opts, :triplets, {}) diff --git a/lib/scholar/neighbors/brute_knn.ex b/lib/scholar/neighbors/brute_knn.ex index 27afa7d6..131b1d6d 100644 --- a/lib/scholar/neighbors/brute_knn.ex +++ b/lib/scholar/neighbors/brute_knn.ex @@ -32,6 +32,12 @@ defmodule Scholar.Neighbors.BruteKNN do * `:cosine` - Cosine metric. + * `:euclidean` - Euclidean metric. + + * `:squared_euclidean` - Squared Euclidean metric. + + * `:manhattan` - Manhattan metric. + * Anonymous function of arity 2 that takes two rank-2 tensors. """ ], diff --git a/lib/scholar/neighbors/utils.ex b/lib/scholar/neighbors/utils.ex index 9da9ed5f..d111e301 100644 --- a/lib/scholar/neighbors/utils.ex +++ b/lib/scholar/neighbors/utils.ex @@ -22,6 +22,14 @@ defmodule Scholar.Neighbors.Utils do {:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: p)} end + def pairwise_metric(:euclidean), do: {:ok, &Scholar.Metrics.Distance.pairwise_euclidean/2} + + def pairwise_metric(:squared_euclidean), + do: {:ok, &Scholar.Metrics.Distance.pairwise_squared_euclidean/2} + + def pairwise_metric(:manhattan), + do: {:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: 1)} + def pairwise_metric(metric) when is_function(metric, 2), do: {:ok, metric} def pairwise_metric(metric) do diff --git a/test/scholar/manifold/trimap_test.exs b/test/scholar/manifold/trimap_test.exs index 971b384d..ef62989a 100644 --- a/test/scholar/manifold/trimap_test.exs +++ b/test/scholar/manifold/trimap_test.exs @@ -9,7 +9,7 @@ defmodule Scholar.Manifold.TrimapTest do key = Nx.Random.key(42) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -34,7 +34,7 @@ defmodule Scholar.Manifold.TrimapTest do key = Nx.Random.key(42) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -62,7 +62,7 @@ defmodule Scholar.Manifold.TrimapTest do key = Nx.Random.key(42) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -91,7 +91,7 @@ defmodule Scholar.Manifold.TrimapTest do weights = Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0]) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -127,7 +127,7 @@ defmodule Scholar.Manifold.TrimapTest do ]) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -153,7 +153,7 @@ defmodule Scholar.Manifold.TrimapTest do key = Nx.Random.key(42) res = - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, @@ -183,7 +183,7 @@ defmodule Scholar.Manifold.TrimapTest do assert_raise ArgumentError, "Number of points must be greater than 2", fn -> - Scholar.Manifold.Trimap.embed(x, + Scholar.Manifold.Trimap.transform(x, num_components: 2, key: key, num_inliers: 10, @@ -202,7 +202,7 @@ defmodule Scholar.Manifold.TrimapTest do "Triplets and weights must be either not initialized or have the same size of axis zero and rank of triplets must be 2 and rank of weights must be 1", fn -> - Trimap.embed(x, + Trimap.transform(x, num_components: 2, key: key, num_inliers: 3, From 9350e14358b920ed0ed1fb97f02b635ecc28f697 Mon Sep 17 00:00:00 2001 From: Mateusz Date: Thu, 13 Jun 2024 23:53:29 +0200 Subject: [PATCH 4/4] Add :auto option for knn_algo auto detection --- lib/scholar/manifold/trimap.ex | 78 ++++++++++++++++++--------- test/scholar/manifold/trimap_test.exs | 20 +++---- 2 files changed, 59 insertions(+), 39 deletions(-) diff --git a/lib/scholar/manifold/trimap.ex b/lib/scholar/manifold/trimap.ex index 32792bfa..abbaa425 100644 --- a/lib/scholar/manifold/trimap.ex +++ b/lib/scholar/manifold/trimap.ex @@ -114,14 +114,18 @@ defmodule Scholar.Manifold.Trimap do Metric used to compute the distances. """ ], - algorithm: [ - type: {:in, [:nndescent, :large_vis]}, - default: :large_vis, + knn_algorithm: [ + type: {:in, [:auto, :nndescent, :large_vis, :brute]}, + default: :auto, doc: ~S""" Algorithm used to compute the nearest neighbors. Possible values: * `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details. * `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LargeVis` for more details. + + * `:brute` - Brute force algorithm. See `Scholar.Neighbors.BruteKNN` for more details. + + * `:auto` - Automatically selects the algorithm based on the number of points. """ ] ] @@ -301,37 +305,59 @@ defmodule Scholar.Manifold.Trimap do num_extra = min(num_inliners + 50, num_points) neighbors = - if Nx.axis_size(inputs, 0) <= 500 do - model = - Scholar.Neighbors.BruteKNN.fit(inputs, - num_neighbors: num_extra, - metric: opts[:metric] - ) - - {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) - neighbors - else - case opts[:algorithm] do - :nndescent -> - nndescent = - Scholar.Neighbors.NNDescent.fit(inputs, + case opts[:knn_algorithm] do + :brute -> + model = + Scholar.Neighbors.BruteKNN.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) + neighbors + + :nndescent -> + nndescent = + Scholar.Neighbors.NNDescent.fit(inputs, + num_neighbors: num_extra, + tree_init?: false, + metric: opts[:metric], + tol: 1.0e-5, + key: key + ) + + nndescent.nearest_neighbors + + :large_vis -> + {neighbors, _distances} = + Scholar.Neighbors.LargeVis.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric], + key: key + ) + + neighbors + + :auto -> + if Nx.axis_size(inputs, 0) <= 500 do + model = + Scholar.Neighbors.BruteKNN.fit(inputs, num_neighbors: num_extra, - tree_init?: false, - metric: opts[:metric], - tol: 1.0e-5 + metric: opts[:metric] ) - nndescent.nearest_neighbors - - :large_vis -> + {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) + neighbors + else {neighbors, _distances} = Scholar.Neighbors.LargeVis.fit(inputs, num_neighbors: num_extra, - metric: opts[:metric] + metric: opts[:metric], + key: key ) neighbors - end + end end neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1) @@ -436,7 +462,7 @@ defmodule Scholar.Manifold.Trimap do ## Examples iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5}) - iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent) + iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, knn_algorithm: :nndescent) """ deftransform transform(inputs, opts \\ []) do opts = NimbleOptions.validate!(opts, @opts_schema) diff --git a/test/scholar/manifold/trimap_test.exs b/test/scholar/manifold/trimap_test.exs index ef62989a..cffe1b92 100644 --- a/test/scholar/manifold/trimap_test.exs +++ b/test/scholar/manifold/trimap_test.exs @@ -13,8 +13,7 @@ defmodule Scholar.Manifold.TrimapTest do num_components: 2, key: key, num_inliers: 3, - num_outliers: 1, - algorithm: :nndescent + num_outliers: 1 ) expected = @@ -41,8 +40,7 @@ defmodule Scholar.Manifold.TrimapTest do num_outliers: 1, num_random: 5, weight_temp: 0.1, - learning_rate: 0.3, - algorithm: :nndescent + learning_rate: 0.3 ) expected = @@ -68,8 +66,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, num_iters: 100, - init_embedding_type: 1, - algorithm: :nndescent + init_embedding_type: 1 ) expected = @@ -97,8 +94,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights, - algorithm: :nndescent + weights: weights ) expected = @@ -132,8 +128,7 @@ defmodule Scholar.Manifold.TrimapTest do key: key, num_inliers: 3, num_outliers: 1, - init_embeddings: init_embeddings, - algorithm: :nndescent + init_embeddings: init_embeddings ) expected = @@ -159,7 +154,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, metric: :manhattan, - algorithm: :nndescent + knn_algorithm: :nndescent ) expected = @@ -208,8 +203,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights, - algorithm: :nndescent + weights: weights ) end end