Skip to content

Commit

Permalink
Make nn algorithm configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Jun 10, 2024
1 parent 32a5b56 commit 9d2e199
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 17 deletions.
43 changes: 33 additions & 10 deletions lib/scholar/manifold/trimap.ex
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ defmodule Scholar.Manifold.Trimap do
doc: ~S"""
Metric used to compute the distances.
"""
],
algorithm: [
type: {:in, [:nndescent, :large_vis]},
default: :large_vis,
doc: ~S"""
Algorithm used to compute the nearest neighbors. Possible values:
* `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details.
* `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LaregVis` for more details.
"""
]
]

Expand Down Expand Up @@ -290,15 +300,28 @@ defmodule Scholar.Manifold.Trimap do
num_points = Nx.axis_size(inputs, 0)
num_extra = min(num_inliners + 50, num_points)

nndescent =
Scholar.Neighbors.NNDescent.fit(inputs,
num_neighbors: num_extra,
tree_init?: false,
metric: opts[:metric],
tol: 1.0e-5
)

neighbors = nndescent.nearest_neighbors
neighbors =
case opts[:algorithm] do
:nndescent ->
nndescent =
Scholar.Neighbors.NNDescent.fit(inputs,
num_neighbors: num_extra,
tree_init?: false,
metric: opts[:metric],
tol: 1.0e-5
)

nndescent.nearest_neighbors

:large_vis ->
{neighbors, _distances} =
Scholar.Neighbors.LargeVis.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric]
)

neighbors
end

neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1)

Expand Down Expand Up @@ -402,7 +425,7 @@ defmodule Scholar.Manifold.Trimap do
## Examples
iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5})
iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key)
iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, algorithm: :nndescent)
"""
deftransform embed(inputs, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)
Expand Down
28 changes: 21 additions & 7 deletions test/scholar/manifold/trimap_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,15 @@ defmodule Scholar.Manifold.TrimapTest do
test "non default num_inliers and num_outliers" do
x = Nx.iota({5, 6})
key = Nx.Random.key(42)
res = Trimap.embed(x, num_components: 2, key: key, num_inliers: 3, num_outliers: 1)

res =
Trimap.embed(x,
num_components: 2,
key: key,
num_inliers: 3,
num_outliers: 1,
algorithm: :nndescent
)

expected =
Nx.tensor([
Expand All @@ -33,7 +41,8 @@ defmodule Scholar.Manifold.TrimapTest do
num_outliers: 1,
num_random: 5,
weight_temp: 0.1,
learning_rate: 0.3
learning_rate: 0.3,
algorithm: :nndescent
)

expected =
Expand All @@ -59,7 +68,8 @@ defmodule Scholar.Manifold.TrimapTest do
num_inliers: 3,
num_outliers: 1,
num_iters: 100,
init_embedding_type: 1
init_embedding_type: 1,
algorithm: :nndescent
)

expected =
Expand Down Expand Up @@ -87,7 +97,8 @@ defmodule Scholar.Manifold.TrimapTest do
num_inliers: 3,
num_outliers: 1,
triplets: triplets,
weights: weights
weights: weights,
algorithm: :nndescent
)

expected =
Expand Down Expand Up @@ -121,7 +132,8 @@ defmodule Scholar.Manifold.TrimapTest do
key: key,
num_inliers: 3,
num_outliers: 1,
init_embeddings: init_embeddings
init_embeddings: init_embeddings,
algorithm: :nndescent
)

expected =
Expand All @@ -146,7 +158,8 @@ defmodule Scholar.Manifold.TrimapTest do
key: key,
num_inliers: 3,
num_outliers: 1,
metric: :manhattan
metric: :manhattan,
algorithm: :nndescent
)

expected =
Expand Down Expand Up @@ -195,7 +208,8 @@ defmodule Scholar.Manifold.TrimapTest do
num_inliers: 3,
num_outliers: 1,
triplets: triplets,
weights: weights
weights: weights,
algorithm: :nndescent
)
end
end
Expand Down

0 comments on commit 9d2e199

Please sign in to comment.