Skip to content

Commit

Permalink
Update NNDescent (elixir-nx#245)
Browse files Browse the repository at this point in the history
* Add priority queue

* Apply changes to NNDescent

* Apply Suggestions form code review

* Format

* Remove redundant space

* Raise when tree_init and non-euclidean metric

* Update lib/scholar/neighbors/nn_descent.ex

Co-authored-by: Krsto Proroković <[email protected]>

* Update lib/scholar/neighbors/nn_descent.ex

Co-authored-by: José Valim <[email protected]>

* Decrease tolerance

* Ensure that all initial nearest neighbors are initialized and fix tests according to the implemented changes

---------

Co-authored-by: Krsto Proroković <[email protected]>
Co-authored-by: José Valim <[email protected]>
  • Loading branch information
3 people authored Apr 5, 2024
1 parent e037b24 commit 0619dc5
Show file tree
Hide file tree
Showing 4 changed files with 450 additions and 328 deletions.
16 changes: 11 additions & 5 deletions lib/scholar/manifold/trimap.ex
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,14 @@ defmodule Scholar.Manifold.Trimap do
num_extra = min(num_inliners + 50, num_points)

nndescent =
Scholar.Neighbors.NNDescent.fit(inputs, num_neighbors: num_extra, tree_init?: false)
Scholar.Neighbors.NNDescent.fit(inputs,
num_neighbors: num_extra,
tree_init?: false,
metric: opts[:metric],
tol: 1.0e-5
)

neighbors = nndescent.nearest_neighbors
{neighbors, neighbors}

neighbors = Nx.concatenate([Nx.iota({num_points, 1}), neighbors], axis: 1)

Expand Down Expand Up @@ -397,8 +401,8 @@ defmodule Scholar.Manifold.Trimap do
## Examples
iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {10, 10})
iex> Scholar.Manifold.Trimap.embed(inputs, num_inliers: 2, num_outliers: 1, num_components: 2, key: key)
iex> {inputs, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {30, 5})
iex> Scholar.Manifold.Trimap.embed(inputs, num_components: 2, num_inliers: 3, num_outliers: 1, key: key)
"""
deftransform embed(inputs, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)
Expand Down Expand Up @@ -437,6 +441,7 @@ defmodule Scholar.Manifold.Trimap do
{} ->
inputs =
if num_components > @dim_pca do
inputs = inputs - Nx.mean(inputs, axes: [0])
{u, s, vt} = Nx.LinAlg.SVD.svd(inputs, full_matrices: false)
inputs = Nx.dot(u[[.., 0..@dim_pca]] * s[0..@dim_pca], vt[[0..@dim_pca, ..]])

Expand Down Expand Up @@ -486,7 +491,8 @@ defmodule Scholar.Manifold.Trimap do
gain = Nx.broadcast(Nx.tensor(1.0, type: to_float_type(embeddings)), Nx.shape(embeddings))

{embeddings, _} =
while {embeddings, {vel, gain, lr, triplets, weights, i = 0}}, i < 20 do
while {embeddings, {vel, gain, lr, triplets, weights, i = Nx.s64(0)}},
i < opts[:num_iters] do
gamma = if i < @switch_iter, do: @final_momentum, else: @init_momentum
grad = trimap_loss(embeddings + gamma * vel, triplets, weights)

Expand Down
Loading

0 comments on commit 0619dc5

Please sign in to comment.