diff --git a/lib/scholar/manifold/trimap.ex b/lib/scholar/manifold/trimap.ex index 32792bfa..abbaa425 100644 --- a/lib/scholar/manifold/trimap.ex +++ b/lib/scholar/manifold/trimap.ex @@ -114,14 +114,18 @@ defmodule Scholar.Manifold.Trimap do Metric used to compute the distances. """ ], - algorithm: [ - type: {:in, [:nndescent, :large_vis]}, - default: :large_vis, + knn_algorithm: [ + type: {:in, [:auto, :nndescent, :large_vis, :brute]}, + default: :auto, doc: ~S""" Algorithm used to compute the nearest neighbors. Possible values: * `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details. * `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LargeVis` for more details. + + * `:brute` - Brute force algorithm. See `Scholar.Neighbors.BruteKNN` for more details. + + * `:auto` - Automatically selects the algorithm based on the number of points. """ ] ] @@ -301,37 +305,59 @@ defmodule Scholar.Manifold.Trimap do num_extra = min(num_inliners + 50, num_points) neighbors = - if Nx.axis_size(inputs, 0) <= 500 do - model = - Scholar.Neighbors.BruteKNN.fit(inputs, - num_neighbors: num_extra, - metric: opts[:metric] - ) - - {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) - neighbors - else - case opts[:algorithm] do - :nndescent -> - nndescent = - Scholar.Neighbors.NNDescent.fit(inputs, + case opts[:knn_algorithm] do + :brute -> + model = + Scholar.Neighbors.BruteKNN.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric] + ) + + {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) + neighbors + + :nndescent -> + nndescent = + Scholar.Neighbors.NNDescent.fit(inputs, + num_neighbors: num_extra, + tree_init?: false, + metric: opts[:metric], + tol: 1.0e-5, + key: key + ) + + nndescent.nearest_neighbors + + :large_vis -> + {neighbors, _distances} = + Scholar.Neighbors.LargeVis.fit(inputs, + num_neighbors: num_extra, + metric: opts[:metric], + key: key + ) + + neighbors + + :auto -> + if Nx.axis_size(inputs, 0) <= 500 do + model = + Scholar.Neighbors.BruteKNN.fit(inputs, num_neighbors: num_extra, - tree_init?: false, - metric: opts[:metric], - tol: 1.0e-5 + metric: opts[:metric] ) - nndescent.nearest_neighbors - - :large_vis -> + {neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs) + neighbors + else {neighbors, _distances} = Scholar.Neighbors.LargeVis.fit(inputs, num_neighbors: num_extra, - metric: opts[:metric] + metric: opts[:metric], + key: key ) neighbors - end + end end neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1) @@ -436,7 +462,7 @@ defmodule Scholar.Manifold.Trimap do ## Examples iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5}) - iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent) + iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, knn_algorithm: :nndescent) """ deftransform transform(inputs, opts \\ []) do opts = NimbleOptions.validate!(opts, @opts_schema) diff --git a/test/scholar/manifold/trimap_test.exs b/test/scholar/manifold/trimap_test.exs index ef62989a..cffe1b92 100644 --- a/test/scholar/manifold/trimap_test.exs +++ b/test/scholar/manifold/trimap_test.exs @@ -13,8 +13,7 @@ defmodule Scholar.Manifold.TrimapTest do num_components: 2, key: key, num_inliers: 3, - num_outliers: 1, - algorithm: :nndescent + num_outliers: 1 ) expected = @@ -41,8 +40,7 @@ defmodule Scholar.Manifold.TrimapTest do num_outliers: 1, num_random: 5, weight_temp: 0.1, - learning_rate: 0.3, - algorithm: :nndescent + learning_rate: 0.3 ) expected = @@ -68,8 +66,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, num_iters: 100, - init_embedding_type: 1, - algorithm: :nndescent + init_embedding_type: 1 ) expected = @@ -97,8 +94,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights, - algorithm: :nndescent + weights: weights ) expected = @@ -132,8 +128,7 @@ defmodule Scholar.Manifold.TrimapTest do key: key, num_inliers: 3, num_outliers: 1, - init_embeddings: init_embeddings, - algorithm: :nndescent + init_embeddings: init_embeddings ) expected = @@ -159,7 +154,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, metric: :manhattan, - algorithm: :nndescent + knn_algorithm: :nndescent ) expected = @@ -208,8 +203,7 @@ defmodule Scholar.Manifold.TrimapTest do num_inliers: 3, num_outliers: 1, triplets: triplets, - weights: weights, - algorithm: :nndescent + weights: weights ) end end