Skip to content

Commit

Permalink
Add :auto option for knn_algo auto detection
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Jun 13, 2024
1 parent e98d0ac commit 9350e14
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 39 deletions.
78 changes: 52 additions & 26 deletions lib/scholar/manifold/trimap.ex
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,18 @@ defmodule Scholar.Manifold.Trimap do
Metric used to compute the distances.
"""
],
algorithm: [
type: {:in, [:nndescent, :large_vis]},
default: :large_vis,
knn_algorithm: [
type: {:in, [:auto, :nndescent, :large_vis, :brute]},
default: :auto,
doc: ~S"""
Algorithm used to compute the nearest neighbors. Possible values:
* `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details.
* `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LargeVis` for more details.
* `:brute` - Brute force algorithm. See `Scholar.Neighbors.BruteKNN` for more details.
* `:auto` - Automatically selects the algorithm based on the number of points.
"""
]
]
Expand Down Expand Up @@ -301,37 +305,59 @@ defmodule Scholar.Manifold.Trimap do
num_extra = min(num_inliners + 50, num_points)

neighbors =
if Nx.axis_size(inputs, 0) <= 500 do
model =
Scholar.Neighbors.BruteKNN.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric]
)

{neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs)
neighbors
else
case opts[:algorithm] do
:nndescent ->
nndescent =
Scholar.Neighbors.NNDescent.fit(inputs,
case opts[:knn_algorithm] do
:brute ->
model =
Scholar.Neighbors.BruteKNN.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric]
)

{neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs)
neighbors

:nndescent ->
nndescent =
Scholar.Neighbors.NNDescent.fit(inputs,
num_neighbors: num_extra,
tree_init?: false,
metric: opts[:metric],
tol: 1.0e-5,
key: key
)

nndescent.nearest_neighbors

:large_vis ->
{neighbors, _distances} =
Scholar.Neighbors.LargeVis.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric],
key: key
)

neighbors

:auto ->
if Nx.axis_size(inputs, 0) <= 500 do
model =
Scholar.Neighbors.BruteKNN.fit(inputs,
num_neighbors: num_extra,
tree_init?: false,
metric: opts[:metric],
tol: 1.0e-5
metric: opts[:metric]
)

nndescent.nearest_neighbors

:large_vis ->
{neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs)
neighbors
else
{neighbors, _distances} =
Scholar.Neighbors.LargeVis.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric]
metric: opts[:metric],
key: key
)

neighbors
end
end
end

neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1)
Expand Down Expand Up @@ -436,7 +462,7 @@ defmodule Scholar.Manifold.Trimap do
## Examples
iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5})
iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent)
iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, knn_algorithm: :nndescent)
"""
deftransform transform(inputs, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)
Expand Down
20 changes: 7 additions & 13 deletions test/scholar/manifold/trimap_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ defmodule Scholar.Manifold.TrimapTest do
num_components: 2,
key: key,
num_inliers: 3,
num_outliers: 1,
algorithm: :nndescent
num_outliers: 1
)

expected =
Expand All @@ -41,8 +40,7 @@ defmodule Scholar.Manifold.TrimapTest do
num_outliers: 1,
num_random: 5,
weight_temp: 0.1,
learning_rate: 0.3,
algorithm: :nndescent
learning_rate: 0.3
)

expected =
Expand All @@ -68,8 +66,7 @@ defmodule Scholar.Manifold.TrimapTest do
num_inliers: 3,
num_outliers: 1,
num_iters: 100,
init_embedding_type: 1,
algorithm: :nndescent
init_embedding_type: 1
)

expected =
Expand Down Expand Up @@ -97,8 +94,7 @@ defmodule Scholar.Manifold.TrimapTest do
num_inliers: 3,
num_outliers: 1,
triplets: triplets,
weights: weights,
algorithm: :nndescent
weights: weights
)

expected =
Expand Down Expand Up @@ -132,8 +128,7 @@ defmodule Scholar.Manifold.TrimapTest do
key: key,
num_inliers: 3,
num_outliers: 1,
init_embeddings: init_embeddings,
algorithm: :nndescent
init_embeddings: init_embeddings
)

expected =
Expand All @@ -159,7 +154,7 @@ defmodule Scholar.Manifold.TrimapTest do
num_inliers: 3,
num_outliers: 1,
metric: :manhattan,
algorithm: :nndescent
knn_algorithm: :nndescent
)

expected =
Expand Down Expand Up @@ -208,8 +203,7 @@ defmodule Scholar.Manifold.TrimapTest do
num_inliers: 3,
num_outliers: 1,
triplets: triplets,
weights: weights,
algorithm: :nndescent
weights: weights
)
end
end
Expand Down

0 comments on commit 9350e14

Please sign in to comment.