forked from elixir-nx/scholar
Showing 4 changed files with 284 additions and 62 deletions.
@@ -0,0 +1,67 @@

defmodule Scholar.Metrics.Neighbors do
  @moduledoc """
  Metrics for evaluating the results of approximate k-nearest neighbor search algorithms.
  """

  import Nx.Defn

  @doc """
  Computes the recall of predicted k-nearest neighbors given the true k-nearest neighbors.

  Recall is defined as the average fraction of nearest neighbors the algorithm predicted correctly.

  ## Examples

      iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]])
      iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_true)
      #Nx.Tensor<
        f32
        1.0
      >

      iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]])
      iex> neighbors_pred = Nx.tensor([[0, 1], [1, 0], [2, 0]])
      iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_pred)
      #Nx.Tensor<
        f32
        0.6666666865348816
      >
  """
  defn recall(neighbors_true, neighbors_pred) do
    if Nx.rank(neighbors_true) != 2 do
      raise ArgumentError,
            """
            expected true neighbors to have shape {num_samples, num_neighbors}, \
            got tensor with shape: #{inspect(Nx.shape(neighbors_true))}\
            """
    end

    if Nx.rank(neighbors_pred) != 2 do
      raise ArgumentError,
            """
            expected predicted neighbors to have shape {num_samples, num_neighbors}, \
            got tensor with shape: #{inspect(Nx.shape(neighbors_pred))}\
            """
    end

    if Nx.axis_size(neighbors_true, 0) != Nx.axis_size(neighbors_pred, 0) do
      raise ArgumentError,
            """
            expected true and predicted neighbors to have the same axis 0 size, \
            got #{inspect(Nx.axis_size(neighbors_true, 0))} and #{inspect(Nx.axis_size(neighbors_pred, 0))}\
            """
    end

    if Nx.axis_size(neighbors_true, 1) != Nx.axis_size(neighbors_pred, 1) do
      raise ArgumentError,
            """
            expected true and predicted neighbors to have the same axis 1 size, \
            got #{inspect(Nx.axis_size(neighbors_true, 1))} and #{inspect(Nx.axis_size(neighbors_pred, 1))}\
            """
    end

    {n, k} = Nx.shape(neighbors_true)
    concatenated = Nx.concatenate([neighbors_true, neighbors_pred], axis: 1) |> Nx.sort(axis: 1)
    duplicate_mask = concatenated[[.., 0..(2 * k - 2)]] == concatenated[[.., 1..(2 * k - 1)]]
    duplicate_mask |> Nx.sum() |> Nx.divide(n * k)
  end
end
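The duplicate-counting step at the end of recall/2 deserves a quick illustration. The sketch below is not part of the commit; it walks a single made-up row through the same trick, assuming (as the function does) that the indices within each row of the true and predicted neighbors contain no repeats, so every adjacent equal pair in the sorted concatenation marks exactly one correctly recovered neighbor.

# Illustrative one-row walkthrough of the trick used in recall/2 above.
# True and predicted neighbor rows (each free of duplicates within the row):
neighbors_true = Nx.tensor([[0, 1, 4]])
neighbors_pred = Nx.tensor([[4, 1, 7]])

# Concatenating and sorting puts every shared index next to its copy:
# sorted is [[0, 1, 1, 4, 4, 7]]
sorted = Nx.concatenate([neighbors_true, neighbors_pred], axis: 1) |> Nx.sort(axis: 1)

# Adjacent equal pairs count the shared indices (here 1 and 4), so this row's
# recall is 2 / 3, matching what recall/2 averages over all rows.
matches = Nx.equal(sorted[[.., 0..-2//1]], sorted[[.., 1..-1//1]]) |> Nx.sum()
recall = Nx.divide(matches, 3)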
@@ -0,0 +1,149 @@

defmodule Scholar.Neighbors.LargeVis do
  @moduledoc """
  LargeVis algorithm for approximate k-nearest neighbor (k-NN) graph construction.

  The algorithm works in the following way. First, the approximate k-NN graph is constructed
  using a random projection forest. Then, the graph is refined by looking at the neighbors of
  neighbors of every point for a fixed number of iterations. This step is called NN-expansion.

  ## References

    * [Visualizing Large-scale and High-dimensional Data](https://arxiv.org/abs/1602.00370).
  """

  import Nx.Defn
  import Scholar.Shared
  require Nx
  alias Scholar.Neighbors.RandomProjectionForest, as: Forest
  alias Scholar.Neighbors.Utils

  opts = [
    num_neighbors: [
      required: true,
      type: :pos_integer,
      doc: "The number of neighbors in the graph."
    ],
    min_leaf_size: [
      type: :pos_integer,
      doc: """
      The minimum number of points in every leaf.
      Must be at least num_neighbors.
      If not provided, it is set based on the number of neighbors.
      """
    ],
    num_trees: [
      type: :pos_integer,
      doc: """
      The number of trees in the random projection forest.
      If not provided, it is set based on the dataset size.
      """
    ],
    num_iters: [
      type: :non_neg_integer,
      default: 1,
      doc: "The number of times to perform neighborhood expansion."
    ],
    key: [
      type: {:custom, Scholar.Options, :key, []},
      doc: """
      Used for random number generation in parameter initialization.
      If the key is not provided, it is set to `Nx.Random.key(System.system_time())`.
      """
    ]
  ]

  @opts_schema NimbleOptions.new!(opts)

  @doc """
  Constructs the approximate k-NN graph with LargeVis.

  Returns neighbor indices and distances.

  ## Examples

      iex> key = Nx.Random.key(12)
      iex> tensor = Nx.iota({5, 2})
      iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, min_leaf_size: 2, num_trees: 3, key: key)
      iex> graph
      #Nx.Tensor<
        u32[5][2]
        [
          [0, 1],
          [1, 0],
          [2, 1],
          [3, 2],
          [4, 3]
        ]
      >
      iex> distances
      #Nx.Tensor<
        f32[5][2]
        [
          [0.0, 8.0],
          [0.0, 8.0],
          [0.0, 8.0],
          [0.0, 8.0],
          [0.0, 8.0]
        ]
      >
  """
  deftransform fit(tensor, opts) do
    if Nx.rank(tensor) != 2 do
      raise ArgumentError,
            """
            expected input tensor to have shape {num_samples, num_features}, \
            got tensor with shape: #{inspect(Nx.shape(tensor))}\
            """
    end

    opts = NimbleOptions.validate!(opts, @opts_schema)
    k = opts[:num_neighbors]
    min_leaf_size = opts[:min_leaf_size] || max(10, 2 * k)

    size = Nx.axis_size(tensor, 0)
    num_trees = opts[:num_trees] || 5 + round(:math.pow(size, 0.25))
    key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)

    fit_n(tensor,
      num_neighbors: k,
      min_leaf_size: min_leaf_size,
      num_trees: num_trees,
      num_iters: opts[:num_iters],
      key: key
    )
  end

  defnp fit_n(tensor, opts) do
    forest =
      Forest.fit(tensor,
        num_neighbors: opts[:num_neighbors],
        min_leaf_size: opts[:min_leaf_size],
        num_trees: opts[:num_trees],
        key: opts[:key]
      )

    {graph, _} = Forest.predict(forest, tensor)
    expand(graph, tensor, num_iters: opts[:num_iters])
  end

  defn expand(graph, tensor, opts) do
    num_iters = opts[:num_iters]
    {n, k} = Nx.shape(graph)

    {result, _} =
      while {
              {
                graph,
                _distances = Nx.broadcast(Nx.tensor(:nan, type: to_float_type(tensor)), {n, k})
              },
              {tensor, iter = 0}
            },
            iter < num_iters do
        {expansion_iter(graph, tensor), {tensor, iter + 1}}
      end

    result
  end

  defnp expansion_iter(graph, tensor) do
    {size, k} = Nx.shape(graph)
    candidate_indices = Nx.take(graph, graph) |> Nx.reshape({size, k * k})
    candidate_indices = Nx.concatenate([graph, candidate_indices], axis: 1)

    Utils.find_neighbors(tensor, tensor, candidate_indices, num_neighbors: k)
  end
end
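For readers following expansion_iter/2 above, here is a small standalone sketch (illustrative tensors, not from the commit) of how a single NN-expansion step assembles its candidate set: every point keeps its current approximate neighbors and adds the neighbors of those neighbors, and Utils.find_neighbors/4 then re-ranks the candidates and keeps the best k.

# graph row i holds the current k = 2 approximate neighbors of point i.
graph = Nx.tensor([[0, 1], [1, 0], [2, 1]], type: :u32)
{n, k} = Nx.shape(graph)

# Second hop: look up the neighbor lists of each point's neighbors,
# giving k * k candidate indices per point, shape {n, k * k}.
second_hop = Nx.take(graph, graph) |> Nx.reshape({n, k * k})

# Candidates handed to Utils.find_neighbors/4: current neighbors first,
# then the second-hop ones; duplicates are masked out during re-ranking.
candidates = Nx.concatenate([graph, second_hop], axis: 1)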
@@ -0,0 +1,65 @@

defmodule Scholar.Neighbors.Utils do
  @moduledoc false
  import Nx.Defn
  require Nx

  defn find_neighbors(query, data, candidate_indices, opts) do
    k = opts[:num_neighbors]
    {size, length} = Nx.shape(candidate_indices)

    distances =
      query
      |> Nx.new_axis(1)
      |> Nx.subtract(Nx.take(data, candidate_indices))
      |> Nx.pow(2)
      |> Nx.sum(axes: [2])

    distances =
      if length > 1 do
        sorted_indices = Nx.argsort(candidate_indices, axis: 1, stable: true)
        inverse = inverse_permutation(sorted_indices)
        sorted = Nx.take_along_axis(candidate_indices, sorted_indices, axis: 1)

        duplicate_mask =
          Nx.concatenate(
            [
              Nx.broadcast(0, {size, 1}),
              Nx.equal(sorted[[.., 0..-2//1]], sorted[[.., 1..-1//1]])
            ],
            axis: 1
          )
          |> Nx.take_along_axis(inverse, axis: 1)

        Nx.select(duplicate_mask, :infinity, distances)
      else
        distances
      end

    indices = Nx.argsort(distances, axis: 1) |> Nx.slice_along_axis(0, k, axis: 1)

    neighbor_indices =
      Nx.take(
        Nx.vectorize(candidate_indices, :samples),
        Nx.vectorize(indices, :samples)
      )
      |> Nx.devectorize()
      |> Nx.rename(nil)

    neighbor_distances = Nx.take_along_axis(distances, indices, axis: 1)

    {neighbor_indices, neighbor_distances}
  end

  defnp inverse_permutation(indices) do
    {size, length} = Nx.shape(indices)
    target = Nx.broadcast(Nx.u32(0), {size, length})
    samples = Nx.iota({size, length, 1}, axis: 0)

    indices =
      Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2)
      |> Nx.reshape({size * length, 2})

    updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length})
    Nx.indexed_add(target, indices, updates)
  end
end
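Finally, a hedged sketch of the scatter performed by inverse_permutation/1 above, on a single made-up row: writing the column positions 0..length-1 into the slots named by a permutation yields that permutation's inverse, which find_neighbors/4 uses to map the duplicate mask from sorted order back to the original candidate order.

# One-row example of the scatter in inverse_permutation/1 above.
perm = Nx.tensor([[2, 0, 1]], type: :u32)
{size, length} = Nx.shape(perm)

target = Nx.broadcast(Nx.u32(0), {size, length})
samples = Nx.iota({size, length, 1}, axis: 0)

# Pair every value of perm with its row index to form scatter coordinates.
scatter_indices =
  Nx.concatenate([samples, Nx.new_axis(perm, 2)], axis: 2)
  |> Nx.reshape({size * length, 2})

# Scatter the column positions 0, 1, 2 into the slots perm points at.
updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length})
inverse = Nx.indexed_add(target, scatter_indices, updates)
# inverse is [[1, 2, 0]]: for every value v in perm, inverse[v] recovers the
# position where v appeared, so composing the two gives the identity.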