Skip to content

Commit

Permalink
LargeVis (elixir-nx#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
krstopro authored Jan 21, 2024
1 parent db9efd7 commit cd64e15
Show file tree
Hide file tree
Showing 4 changed files with 284 additions and 62 deletions.
67 changes: 67 additions & 0 deletions lib/scholar/metrics/neighbors.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
defmodule Scholar.Metrics.Neighbors do
  @moduledoc """
  Metrics for evaluating the results of approximate k-nearest neighbor search algorithms.
  """

  import Nx.Defn

  @doc """
  Computes the recall of predicted k-nearest neighbors given the true k-nearest neighbors.

  Recall is defined as the average fraction of nearest neighbors the algorithm
  predicted correctly.

  ## Examples

      iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]])
      iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_true)
      #Nx.Tensor<
        f32
        1.0
      >

      iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]])
      iex> neighbors_pred = Nx.tensor([[0, 1], [1, 0], [2, 0]])
      iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_pred)
      #Nx.Tensor<
        f32
        0.6666666865348816
      >
  """
  defn recall(neighbors_true, neighbors_pred) do
    # Shape validation. Rank and axis sizes are static tensor properties
    # inside defn, so these checks (and raises) happen when the computation
    # is traced, not on device.
    if Nx.rank(neighbors_true) != 2 do
      raise ArgumentError,
            """
            expected true neighbors to have shape {num_samples, num_neighbors}, \
            got tensor with shape: #{inspect(Nx.shape(neighbors_true))}\
            """
    end

    if Nx.rank(neighbors_pred) != 2 do
      raise ArgumentError,
            """
            expected predicted neighbors to have shape {num_samples, num_neighbors}, \
            got tensor with shape: #{inspect(Nx.shape(neighbors_pred))}\
            """
    end

    if Nx.axis_size(neighbors_true, 0) != Nx.axis_size(neighbors_pred, 0) do
      raise ArgumentError,
            """
            expected true and predicted neighbors to have the same axis 0 size, \
            got #{inspect(Nx.axis_size(neighbors_true, 0))} and #{inspect(Nx.axis_size(neighbors_pred, 0))}\
            """
    end

    if Nx.axis_size(neighbors_true, 1) != Nx.axis_size(neighbors_pred, 1) do
      raise ArgumentError,
            """
            expected true and predicted neighbors to have the same axis 1 size, \
            got #{inspect(Nx.axis_size(neighbors_true, 1))} and #{inspect(Nx.axis_size(neighbors_pred, 1))}\
            """
    end

    {n, k} = Nx.shape(neighbors_true)
    # Sorting each concatenated row places any index shared by the true and
    # predicted lists next to its copy, so shared neighbors show up as
    # adjacent duplicates. NOTE(review): assumes indices are unique within
    # each row of each input, as is standard for k-NN results — confirm.
    concatenated = Nx.concatenate([neighbors_true, neighbors_pred], axis: 1) |> Nx.sort(axis: 1)
    duplicate_mask = concatenated[[.., 0..(2 * k - 2)]] == concatenated[[.., 1..(2 * k - 1)]]
    # Each duplicate is one correctly predicted neighbor; average over n * k.
    duplicate_mask |> Nx.sum() |> Nx.divide(n * k)
  end
end
149 changes: 149 additions & 0 deletions lib/scholar/neighbors/large_vis.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
defmodule Scholar.Neighbors.LargeVis do
  @moduledoc """
  LargeVis algorithm for approximate k-nearest neighbor (k-NN) graph construction.

  The algorithm works in the following way. First, the approximate k-NN graph is constructed
  using a random projection forest. Then, the graph is refined by looking at the neighbors of
  neighbors of every point for a fixed number of iterations. This step is called NN-expansion.

  ## References

    * [Visualizing Large-scale and High-dimensional Data](https://arxiv.org/abs/1602.00370).
  """

  import Nx.Defn
  import Scholar.Shared
  require Nx
  alias Scholar.Neighbors.RandomProjectionForest, as: Forest
  alias Scholar.Neighbors.Utils

  opts = [
    num_neighbors: [
      required: true,
      type: :pos_integer,
      doc: "The number of neighbors in the graph."
    ],
    min_leaf_size: [
      type: :pos_integer,
      doc: """
      The minimum number of points in every leaf.
      Must be at least num_neighbors.
      If not provided, it is set based on the number of neighbors.
      """
    ],
    num_trees: [
      type: :pos_integer,
      doc: """
      The number of trees in random projection forest.
      If not provided, it is set based on the dataset size.
      """
    ],
    num_iters: [
      type: :non_neg_integer,
      default: 1,
      doc: "The number of times to perform neighborhood expansion."
    ],
    key: [
      type: {:custom, Scholar.Options, :key, []},
      doc: """
      Used for random number generation in parameter initialization.
      If the key is not provided, it is set to `Nx.Random.key(System.system_time())`.
      """
    ]
  ]

  @opts_schema NimbleOptions.new!(opts)

  @doc """
  Constructs the approximate k-NN graph with LargeVis.

  Returns neighbor indices and distances.

  ## Examples

      iex> key = Nx.Random.key(12)
      iex> tensor = Nx.iota({5, 2})
      iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, min_leaf_size: 2, num_trees: 3, key: key)
      iex> graph
      #Nx.Tensor<
        u32[5][2]
        [
          [0, 1],
          [1, 0],
          [2, 1],
          [3, 2],
          [4, 3]
        ]
      >
      iex> distances
      #Nx.Tensor<
        f32[5][2]
        [
          [0.0, 8.0],
          [0.0, 8.0],
          [0.0, 8.0],
          [0.0, 8.0],
          [0.0, 8.0]
        ]
      >
  """
  deftransform fit(tensor, opts) do
    if Nx.rank(tensor) != 2 do
      raise ArgumentError,
            """
            expected input tensor to have shape {num_samples, num_features}, \
            got tensor with shape: #{inspect(Nx.shape(tensor))}\
            """
    end

    opts = NimbleOptions.validate!(opts, @opts_schema)
    k = opts[:num_neighbors]
    # Leaves must be able to hold at least k points so the forest can return
    # k candidates per query.
    min_leaf_size = opts[:min_leaf_size] || max(10, 2 * k)

    size = Nx.axis_size(tensor, 0)
    # Heuristic: grow the forest with the fourth root of the dataset size.
    num_trees = opts[:num_trees] || 5 + round(:math.pow(size, 0.25))
    key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)

    # Fix: num_iters was validated but previously not forwarded, so fit_n
    # read nil from its opts and the expansion loop bound was undefined.
    fit_n(
      tensor,
      num_neighbors: k,
      min_leaf_size: min_leaf_size,
      num_trees: num_trees,
      num_iters: opts[:num_iters],
      key: key
    )
  end

  defnp fit_n(tensor, opts) do
    # Step 1: initial approximate k-NN graph from a random projection forest.
    forest =
      Forest.fit(tensor,
        num_neighbors: opts[:num_neighbors],
        min_leaf_size: opts[:min_leaf_size],
        num_trees: opts[:num_trees],
        key: opts[:key]
      )

    {graph, _} = Forest.predict(forest, tensor)
    # Step 2: refine the graph with NN-expansion.
    expand(graph, tensor, num_iters: opts[:num_iters])
  end

  defn expand(graph, tensor, opts) do
    num_iters = opts[:num_iters]
    {n, k} = Nx.shape(graph)

    # Loop state is {{graph, distances}, {tensor, iter}}; each iteration
    # replaces {graph, distances} with the expanded result. Distances start
    # as NaN and are overwritten by the first iteration (they remain NaN
    # when num_iters == 0).
    {result, _} =
      while {
              {
                graph,
                _distances = Nx.broadcast(Nx.tensor(:nan, type: to_float_type(tensor)), {n, k})
              },
              {tensor, iter = 0}
            },
            iter < num_iters do
        {expansion_iter(graph, tensor), {tensor, iter + 1}}
      end

    result
  end

  # One NN-expansion step: every point's candidates are its current neighbors
  # plus the neighbors of those neighbors; the k closest are kept.
  defnp expansion_iter(graph, tensor) do
    {size, k} = Nx.shape(graph)
    candidate_indices = Nx.take(graph, graph) |> Nx.reshape({size, k * k})
    candidate_indices = Nx.concatenate([graph, candidate_indices], axis: 1)

    Utils.find_neighbors(tensor, tensor, candidate_indices, num_neighbors: k)
  end
end
65 changes: 3 additions & 62 deletions lib/scholar/neighbors/random_projection_forest.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
The leaves of the trees are arranged as blocks in the field `indices`. We use the same
hyperplane for all nodes on the same level as in [2].
* [1] - Random projection trees and low dimensional manifolds
* [1] - Randomized partition trees for nearest neighbor search
* [2] - Fast Nearest Neighbor Search through Sparse Random Projections and Voting
"""

import Nx.Defn
import Scholar.Shared
require Nx
alias Scholar.Neighbors.Utils

@derive {Nx.Container,
keep: [:num_neighbors, :depth, :leaf_size, :num_trees],
Expand Down Expand Up @@ -342,7 +343,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
|> Nx.transpose(axes: [1, 0, 2])
|> Nx.reshape({query_size, num_trees * leaf_size})

find_neighbors(query, forest.data, candidate_indices, num_neighbors: k)
Utils.find_neighbors(query, forest.data, candidate_indices, num_neighbors: k)
end

defnp compute_start_indices(forest, query) do
Expand Down Expand Up @@ -400,64 +401,4 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
# Index of the left child of `nodes` in the implicit binary-tree layout.
defnp left_child(nodes) do
  nodes |> Nx.multiply(2) |> Nx.add(1)
end

# Index of the right child of `nodes` in the implicit binary-tree layout.
defnp right_child(nodes) do
  nodes |> Nx.multiply(2) |> Nx.add(2)
end

# For every query row, picks the num_neighbors closest unique points among its
# candidate indices into `data`. Returns {neighbor_indices, neighbor_distances},
# where distances are squared Euclidean.
defnp find_neighbors(query, data, candidate_indices, opts) do
  k = opts[:num_neighbors]
  {size, length} = Nx.shape(candidate_indices)

  # Squared Euclidean distance from each query point to each of its candidates.
  distances =
    query
    |> Nx.new_axis(1)
    |> Nx.subtract(Nx.take(data, candidate_indices))
    |> Nx.pow(2)
    |> Nx.sum(axes: [2])

  distances =
    if length > 1 do
      # Sort candidates within each row; repeated indices then become
      # adjacent, so a duplicate is any element equal to its left neighbor.
      sorted_indices = Nx.argsort(candidate_indices, axis: 1, stable: true)
      inverse = inverse_permutation(sorted_indices)
      sorted = Nx.take_along_axis(candidate_indices, sorted_indices, axis: 1)

      duplicate_mask =
        Nx.concatenate(
          [
            Nx.broadcast(0, {size, 1}),
            Nx.equal(sorted[[.., 0..-2//1]], sorted[[.., 1..-1//1]])
          ],
          axis: 1
        )
        # Map the mask from sorted order back to the original candidate order.
        |> Nx.take_along_axis(inverse, axis: 1)

      # Duplicates get infinite distance so the argsort below never picks them.
      Nx.select(duplicate_mask, :infinity, distances)
    else
      distances
    end

  indices = Nx.argsort(distances, axis: 1) |> Nx.slice_along_axis(0, k, axis: 1)

  # Row-wise gather: vectorizing over the sample axis makes each row of
  # `indices` select from the matching row of `candidate_indices`.
  neighbor_indices =
    Nx.take(
      Nx.vectorize(candidate_indices, :samples),
      Nx.vectorize(indices, :samples)
    )
    |> Nx.devectorize()
    |> Nx.rename(nil)

  neighbor_distances = Nx.take_along_axis(distances, indices, axis: 1)

  {neighbor_indices, neighbor_distances}
end

# Row-wise inverse of a permutation: if indices[i][j] == p then result[i][p] == j.
# Implemented as a scatter of positions (indexed_add into a zero tensor).
defnp inverse_permutation(indices) do
  {size, length} = Nx.shape(indices)
  target = Nx.broadcast(Nx.u32(0), {size, length})
  samples = Nx.iota({size, length, 1}, axis: 0)

  # Build flat {row, permuted_position} coordinate pairs for every element.
  indices =
    Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2)
    |> Nx.reshape({size * length, 2})

  # The value scattered at each coordinate is the element's original column.
  updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length})
  Nx.indexed_add(target, indices, updates)
end
end
65 changes: 65 additions & 0 deletions lib/scholar/neighbors/utils.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
defmodule Scholar.Neighbors.Utils do
  @moduledoc false
  import Nx.Defn
  require Nx

  # Selects, for every query row, the num_neighbors closest unique points among
  # the given candidate indices into `data`. Returns
  # {neighbor_indices, neighbor_distances} with squared Euclidean distances.
  # Candidate rows may contain repeats; duplicates are masked out so each
  # point is considered at most once.
  defn find_neighbors(query, data, candidate_indices, opts) do
    k = opts[:num_neighbors]
    {size, length} = Nx.shape(candidate_indices)

    # Squared Euclidean distance from each query point to each candidate.
    distances =
      query
      |> Nx.new_axis(1)
      |> Nx.subtract(Nx.take(data, candidate_indices))
      |> Nx.pow(2)
      |> Nx.sum(axes: [2])

    distances =
      if length > 1 do
        # Sort each candidate row so repeated indices become adjacent; a
        # duplicate is then any element equal to its left neighbor.
        sorted_indices = Nx.argsort(candidate_indices, axis: 1, stable: true)
        inverse = inverse_permutation(sorted_indices)
        sorted = Nx.take_along_axis(candidate_indices, sorted_indices, axis: 1)

        duplicate_mask =
          Nx.concatenate(
            [
              Nx.broadcast(0, {size, 1}),
              Nx.equal(sorted[[.., 0..-2//1]], sorted[[.., 1..-1//1]])
            ],
            axis: 1
          )
          # Undo the sort so the mask lines up with the original order.
          |> Nx.take_along_axis(inverse, axis: 1)

        # Infinite distance keeps duplicates out of the top-k selection below.
        Nx.select(duplicate_mask, :infinity, distances)
      else
        distances
      end

    indices = Nx.argsort(distances, axis: 1) |> Nx.slice_along_axis(0, k, axis: 1)

    # Row-wise gather: vectorizing over :samples makes each row of `indices`
    # index into the corresponding row of `candidate_indices`.
    neighbor_indices =
      Nx.take(
        Nx.vectorize(candidate_indices, :samples),
        Nx.vectorize(indices, :samples)
      )
      |> Nx.devectorize()
      |> Nx.rename(nil)

    neighbor_distances = Nx.take_along_axis(distances, indices, axis: 1)

    {neighbor_indices, neighbor_distances}
  end

  # Row-wise inverse of a permutation: if indices[i][j] == p then
  # result[i][p] == j. Implemented as a scatter (indexed_add into zeros).
  defnp inverse_permutation(indices) do
    {size, length} = Nx.shape(indices)
    target = Nx.broadcast(Nx.u32(0), {size, length})
    samples = Nx.iota({size, length, 1}, axis: 0)

    # Flat {row, permuted_position} coordinate pairs for every element.
    indices =
      Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2)
      |> Nx.reshape({size * length, 2})

    # Scatter each element's original column index to its permuted position.
    updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length})
    Nx.indexed_add(target, indices, updates)
  end
end

0 comments on commit cd64e15

Please sign in to comment.