From 9d2e19986d40be213feffc9a570e271d3b5306db Mon Sep 17 00:00:00 2001 From: Mateusz Date: Mon, 10 Jun 2024 23:04:35 +0200 Subject: [PATCH] Make nn algorithm configurable --- lib/scholar/manifold/trimap.ex | 43 ++++++++++++++++++++------- test/scholar/manifold/trimap_test.exs | 28 ++++++++++++----- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/lib/scholar/manifold/trimap.ex b/lib/scholar/manifold/trimap.ex index 755c8cde..1bf01a59 100644 --- a/lib/scholar/manifold/trimap.ex +++ b/lib/scholar/manifold/trimap.ex @@ -113,6 +113,16 @@ defmodule Scholar.Manifold.Trimap do doc: ~S""" Metric used to compute the distances. """ + ], + algorithm: [ + type: {:in, [:nndescent, :large_vis]}, + default: :large_vis, + doc: ~S""" + Algorithm used to compute the nearest neighbors. Possible values: + * `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details. + + * `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LaregVis` for more details. + """ ] ] @@ -290,15 +300,28 @@ defmodule Scholar.Manifold.Trimap do num_points = Nx.axis_size(inputs, 0) num_extra = min(num_inliners + 50, num_points) - nndescent = - Scholar.Neighbors.NNDescent.fit(inputs, - num_neighbors: num_extra, - tree_init?: false, - metric: opts[:metric], - tol: 1.0e-5 - ) - - neighbors = nndescent.nearest_neighbors + neighbors = + case opts[:algorithm] do + :nndescent -> + nndescent = + Scholar.Neighbors.NNDescent.fit(inputs, + num_neighbors: num_extra, + tree_init?: false, + metric: opts[:metric], + tol: 1.0e-5 + ) + + nndescent.nearest_neighbors + + :large_vis -> + {neighbors, _distances} = + Scholar.Neighbors.LargeVis.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + neighbors + end neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1) @@ -402,7 +425,7 @@ defmodule Scholar.Manifold.Trimap do ## Examples iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5}) - iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key) + iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent) """ deftransform embed(inputs, opts \\ []) do opts = NimbleOptions.validate!(opts, @opts_schema) diff --git a/test/scholar/manifold/trimap_test.exs b/test/scholar/manifold/trimap_test.exs index 75b0e71e..971b384d 100644 --- a/test/scholar/manifold/trimap_test.exs +++ b/test/scholar/manifold/trimap_test.exs @@ -7,7 +7,15 @@ defmodule Scholar.Manifold.TrimapTest do test "non default num_inliers and num_outliers" do x = Nx.iota({5, 6}) key = Nx.Random.key(42) - res = Trimap.embed(x, num_components: 2, key: key, num_inliers: 3, num_outliers: 1) + + res = + Trimap.embed(x, + num_components: 2, + key: key, + num_inliers: 3, + num_outliers: 1, + algorithm: :nndescent + ) expected = Nx.tensor([ @@ -33,7 +41,8 @@ defmodule Scholar.Manifold.TrimapTest do num_outliers: 1, num_random: 5, weight_temp: 0.1, - learning_rate: 0.3 + learning_rate: 0.3, + algorithm: :nndescent ) expected = @@ -59,7 +68,8 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, num_iters: 100, - init_embedding_type: 1 + init_embedding_type: 1, + algorithm: :nndescent ) expected = @@ -87,7 +97,8 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights + weights: weights, + algorithm: :nndescent ) expected = @@ -121,7 +132,8 @@ defmodule Scholar.Manifold.TrimapTest do key: key, num_inliers: 3, num_outliers: 1, - init_embeddings: init_embeddings + init_embeddings: init_embeddings, + algorithm: :nndescent ) expected = @@ -146,7 +158,8 @@ defmodule Scholar.Manifold.TrimapTest do key: key, num_inliers: 3, num_outliers: 1, - metric: :manhattan + metric: :manhattan, + algorithm: :nndescent ) expected = @@ -195,7 +208,8 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights + weights: weights, + algorithm: :nndescent ) end end