Skip to content

Commit

Permalink
Make nn algorithm configurable (elixir-nx#281)
Browse files Browse the repository at this point in the history
* Make nn algorithm configurable

* Update lib/scholar/manifold/trimap.ex

Co-authored-by: José Valim <[email protected]>

* Make tests pass

* Add :auto option for knn_algo auto detection

---------

Co-authored-by: José Valim <[email protected]>
  • Loading branch information
msluszniak and josevalim authored Jun 14, 2024
1 parent 09c5ac6 commit 433041f
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 20 deletions.
82 changes: 71 additions & 11 deletions lib/scholar/manifold/trimap.ex
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,20 @@ defmodule Scholar.Manifold.Trimap do
doc: ~S"""
Metric used to compute the distances.
"""
],
knn_algorithm: [
type: {:in, [:auto, :nndescent, :large_vis, :brute]},
default: :auto,
doc: ~S"""
Algorithm used to compute the nearest neighbors. Possible values:
* `:nndescent` - Nearest Neighbors Descent. See `Scholar.Neighbors.NNDescent` for more details.
* `:large_vis` - LargeVis algorithm. See `Scholar.Neighbors.LargeVis` for more details.
* `:brute` - Brute force algorithm. See `Scholar.Neighbors.BruteKNN` for more details.
* `:auto` - Automatically selects the algorithm based on the number of points: `:brute` when there are at most 500 points, `:large_vis` otherwise.
"""
]
]

Expand Down Expand Up @@ -290,15 +304,61 @@ defmodule Scholar.Manifold.Trimap do
num_points = Nx.axis_size(inputs, 0)
num_extra = min(num_inliners + 50, num_points)

nndescent =
Scholar.Neighbors.NNDescent.fit(inputs,
num_neighbors: num_extra,
tree_init?: false,
metric: opts[:metric],
tol: 1.0e-5
)

neighbors = nndescent.nearest_neighbors
neighbors =
case opts[:knn_algorithm] do
:brute ->
model =
Scholar.Neighbors.BruteKNN.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric]
)

{neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs)
neighbors

:nndescent ->
nndescent =
Scholar.Neighbors.NNDescent.fit(inputs,
num_neighbors: num_extra,
tree_init?: false,
metric: opts[:metric],
tol: 1.0e-5,
key: key
)

nndescent.nearest_neighbors

:large_vis ->
{neighbors, _distances} =
Scholar.Neighbors.LargeVis.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric],
key: key
)

neighbors

:auto ->
if Nx.axis_size(inputs, 0) <= 500 do
model =
Scholar.Neighbors.BruteKNN.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric]
)

{neighbors, _distances} = Scholar.Neighbors.BruteKNN.predict(model, inputs)
neighbors
else
{neighbors, _distances} =
Scholar.Neighbors.LargeVis.fit(inputs,
num_neighbors: num_extra,
metric: opts[:metric],
key: key
)

neighbors
end
end

neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1)

Expand Down Expand Up @@ -402,9 +462,9 @@ defmodule Scholar.Manifold.Trimap do
## Examples
iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5})
iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key)
iex> Scholar.Manifold.Trimap.transform(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key, knn_algorithm: :nndescent)
"""
deftransform embed(inputs, opts \\ []) do
deftransform transform(inputs, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)
key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)
{triplets, opts} = Keyword.pop(opts, :triplets, {})
Expand Down
6 changes: 6 additions & 0 deletions lib/scholar/neighbors/brute_knn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ defmodule Scholar.Neighbors.BruteKNN do
* `:cosine` - Cosine metric.
* `:euclidean` - Euclidean metric.
* `:squared_euclidean` - Squared Euclidean metric.
* `:manhattan` - Manhattan metric.
* Anonymous function of arity 2 that takes two rank-2 tensors.
"""
],
Expand Down
8 changes: 8 additions & 0 deletions lib/scholar/neighbors/utils.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ defmodule Scholar.Neighbors.Utils do
{:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: p)}
end

def pairwise_metric(:euclidean), do: {:ok, &Scholar.Metrics.Distance.pairwise_euclidean/2}

def pairwise_metric(:squared_euclidean),
do: {:ok, &Scholar.Metrics.Distance.pairwise_squared_euclidean/2}

def pairwise_metric(:manhattan),
do: {:ok, &Scholar.Metrics.Distance.pairwise_minkowski(&1, &2, p: 1)}

def pairwise_metric(metric) when is_function(metric, 2), do: {:ok, metric}

def pairwise_metric(metric) do
Expand Down
26 changes: 17 additions & 9 deletions test/scholar/manifold/trimap_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@ defmodule Scholar.Manifold.TrimapTest do
test "non default num_inliers and num_outliers" do
x = Nx.iota({5, 6})
key = Nx.Random.key(42)
res = Trimap.embed(x, num_components: 2, key: key, num_inliers: 3, num_outliers: 1)

res =
Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 3,
num_outliers: 1
)

expected =
Nx.tensor([
Expand All @@ -26,7 +33,7 @@ defmodule Scholar.Manifold.TrimapTest do
key = Nx.Random.key(42)

res =
Trimap.embed(x,
Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 3,
Expand All @@ -53,7 +60,7 @@ defmodule Scholar.Manifold.TrimapTest do
key = Nx.Random.key(42)

res =
Trimap.embed(x,
Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 3,
Expand Down Expand Up @@ -81,7 +88,7 @@ defmodule Scholar.Manifold.TrimapTest do
weights = Nx.tensor([1.0, 1.0, 1.0, 1.0, 1.0])

res =
Trimap.embed(x,
Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 3,
Expand Down Expand Up @@ -116,7 +123,7 @@ defmodule Scholar.Manifold.TrimapTest do
])

res =
Trimap.embed(x,
Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 3,
Expand All @@ -141,12 +148,13 @@ defmodule Scholar.Manifold.TrimapTest do
key = Nx.Random.key(42)

res =
Trimap.embed(x,
Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 3,
num_outliers: 1,
metric: :manhattan
metric: :manhattan,
knn_algorithm: :nndescent
)

expected =
Expand All @@ -170,7 +178,7 @@ defmodule Scholar.Manifold.TrimapTest do
assert_raise ArgumentError,
"Number of points must be greater than 2",
fn ->
Scholar.Manifold.Trimap.embed(x,
Scholar.Manifold.Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 10,
Expand All @@ -189,7 +197,7 @@ defmodule Scholar.Manifold.TrimapTest do
"Triplets and weights must be either not initialized or have the same
size of axis zero and rank of triplets must be 2 and rank of weights must be 1",
fn ->
Trimap.embed(x,
Trimap.transform(x,
num_components: 2,
key: key,
num_inliers: 3,
Expand Down

0 comments on commit 433041f

Please sign in to comment.