diff --git a/lib/scholar/metrics/neighbors.ex b/lib/scholar/metrics/neighbors.ex new file mode 100644 index 00000000..969d2db7 --- /dev/null +++ b/lib/scholar/metrics/neighbors.ex @@ -0,0 +1,67 @@ +defmodule Scholar.Metrics.Neighbors do + @moduledoc """ + Metrics for evaluating the results of approximate k-nearest neighbor search algorithms. + """ + + import Nx.Defn + + @doc """ + Computes the recall of predicted k-nearest neighbors given the true k-nearest neighbors. + Recall is defined as the average fraction of nearest neighbors the algorithm predicted correctly. + + ## Examples + + iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]]) + iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_true) + #Nx.Tensor< + f32 + 1.0 + > + + iex> neighbors_true = Nx.tensor([[0, 1], [1, 2], [2, 1]]) + iex> neighbors_pred = Nx.tensor([[0, 1], [1, 0], [2, 0]]) + iex> Scholar.Metrics.Neighbors.recall(neighbors_true, neighbors_pred) + #Nx.Tensor< + f32 + 0.6666666865348816 + > + """ + defn recall(neighbors_true, neighbors_pred) do + if Nx.rank(neighbors_true) != 2 do + raise ArgumentError, + """ + expected true neighbors to have shape {num_samples, num_neighbors}, \ + got tensor with shape: #{inspect(Nx.shape(neighbors_true))}\ + """ + end + + if Nx.rank(neighbors_pred) != 2 do + raise ArgumentError, + """ + expected predicted neighbors to have shape {num_samples, num_neighbors}, \ + got tensor with shape: #{inspect(Nx.shape(neighbors_pred))}\ + """ + end + + if Nx.axis_size(neighbors_true, 0) != Nx.axis_size(neighbors_pred, 0) do + raise ArgumentError, + """ + expected true and predicted neighbors to have the same axis 0 size, \ + got #{inspect(Nx.axis_size(neighbors_true, 0))} and #{inspect(Nx.axis_size(neighbors_pred, 0))}\ + """ + end + + if Nx.axis_size(neighbors_true, 1) != Nx.axis_size(neighbors_pred, 1) do + raise ArgumentError, + """ + expected true and predicted neighbors to have the same axis 1 size, \ + got #{inspect(Nx.axis_size(neighbors_true, 1))} 
and #{inspect(Nx.axis_size(neighbors_pred, 1))}\ + """ + end + + {n, k} = Nx.shape(neighbors_true) + concatenated = Nx.concatenate([neighbors_true, neighbors_pred], axis: 1) |> Nx.sort(axis: 1) + duplicate_mask = concatenated[[.., 0..(2 * k - 2)]] == concatenated[[.., 1..(2 * k - 1)]] + duplicate_mask |> Nx.sum() |> Nx.divide(n * k) + end +end diff --git a/lib/scholar/neighbors/large_vis.ex b/lib/scholar/neighbors/large_vis.ex new file mode 100644 index 00000000..3aeb7a7a --- /dev/null +++ b/lib/scholar/neighbors/large_vis.ex @@ -0,0 +1,149 @@ +defmodule Scholar.Neighbors.LargeVis do + @moduledoc """ + LargeVis algorithm for approximate k-nearest neighbor (k-NN) graph construction. + + The algorithm works in the following way. First, the approximate k-NN graph is constructed + using a random projection forest. Then, the graph is refined by looking at the neighbors of + neighbors of every point for a fixed number of iterations. This step is called NN-expansion. + + ## References + + * [Visualizing Large-scale and High-dimensional Data](https://arxiv.org/abs/1602.00370). + """ + + import Nx.Defn + import Scholar.Shared + require Nx + alias Scholar.Neighbors.RandomProjectionForest, as: Forest + alias Scholar.Neighbors.Utils + + opts = [ + num_neighbors: [ + required: true, + type: :pos_integer, + doc: "The number of neighbors in the graph." + ], + min_leaf_size: [ + type: :pos_integer, + doc: """ + The minimum number of points in every leaf. + Must be at least num_neighbors. + If not provided, it is set based on the number of neighbors. + """ + ], + num_trees: [ + type: :pos_integer, + doc: """ + The number of trees in random projection forest. + If not provided, it is set based on the dataset size. + """ + ], + num_iters: [ + type: :non_neg_integer, + default: 1, + doc: "The number of times to perform neighborhood expansion." + ], + key: [ + type: {:custom, Scholar.Options, :key, []}, + doc: """ + Used for random number generation in parameter initialization. 
+ If the key is not provided, it is set to `Nx.Random.key(System.system_time())`. + """ + ] + ] + + @opts_schema NimbleOptions.new!(opts) + + @doc """ + Constructs the approximate k-NN graph with LargeVis. + + Returns neighbor indices and distances. + + ## Examples + + iex> key = Nx.Random.key(12) + iex> tensor = Nx.iota({5, 2}) + iex> {graph, distances} = Scholar.Neighbors.LargeVis.fit(tensor, num_neighbors: 2, min_leaf_size: 2, num_trees: 3, key: key) + iex> graph + #Nx.Tensor< + u32[5][2] + [ + [0, 1], + [1, 0], + [2, 1], + [3, 2], + [4, 3] + ] + > + iex> distances + #Nx.Tensor< + f32[5][2] + [ + [0.0, 8.0], + [0.0, 8.0], + [0.0, 8.0], + [0.0, 8.0], + [0.0, 8.0] + ] + > + """ + deftransform fit(tensor, opts) do + if Nx.rank(tensor) != 2 do + raise ArgumentError, + """ + expected input tensor to have shape {num_samples, num_features}, \ + got tensor with shape: #{inspect(Nx.shape(tensor))}\ + """ + end + + opts = NimbleOptions.validate!(opts, @opts_schema) + k = opts[:num_neighbors] + min_leaf_size = opts[:min_leaf_size] || max(10, 2 * k) + + size = Nx.axis_size(tensor, 0) + num_trees = opts[:num_trees] || 5 + round(:math.pow(size, 0.25)) + key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) + + fit_n(tensor, num_neighbors: k, min_leaf_size: min_leaf_size, num_trees: num_trees, key: key) + end + + defnp fit_n(tensor, opts) do + forest = + Forest.fit(tensor, + num_neighbors: opts[:num_neighbors], + min_leaf_size: opts[:min_leaf_size], + num_trees: opts[:num_trees], + key: opts[:key] + ) + + {graph, _} = Forest.predict(forest, tensor) + expand(graph, tensor, num_iters: opts[:num_iters]) + end + + defn expand(graph, tensor, opts) do + num_iters = opts[:num_iters] + {n, k} = Nx.shape(graph) + + {result, _} = + while { + { + graph, + _distances = Nx.broadcast(Nx.tensor(:nan, type: to_float_type(tensor)), {n, k}) + }, + {tensor, iter = 0} + }, + iter < num_iters do + {expansion_iter(graph, tensor), {tensor, iter + 1}} + end + + result + 
end + + defnp expansion_iter(graph, tensor) do + {size, k} = Nx.shape(graph) + candidate_indices = Nx.take(graph, graph) |> Nx.reshape({size, k * k}) + candidate_indices = Nx.concatenate([graph, candidate_indices], axis: 1) + + Utils.find_neighbors(tensor, tensor, candidate_indices, num_neighbors: k) + end +end diff --git a/lib/scholar/neighbors/random_projection_forest.ex b/lib/scholar/neighbors/random_projection_forest.ex index dacc1dc9..aa4ee31a 100644 --- a/lib/scholar/neighbors/random_projection_forest.ex +++ b/lib/scholar/neighbors/random_projection_forest.ex @@ -13,13 +13,14 @@ defmodule Scholar.Neighbors.RandomProjectionForest do The leaves of the trees are arranged as blocks in the field `indices`. We use the same hyperplane for all nodes on the same level as in [2]. - * [1] - Random projection trees and low dimensional manifolds + * [1] - Randomized partition trees for nearest neighbor search * [2] - Fast Nearest Neighbor Search through Sparse Random Projections and Voting """ import Nx.Defn import Scholar.Shared require Nx + alias Scholar.Neighbors.Utils @derive {Nx.Container, keep: [:num_neighbors, :depth, :leaf_size, :num_trees], @@ -342,7 +343,7 @@ defmodule Scholar.Neighbors.RandomProjectionForest do |> Nx.transpose(axes: [1, 0, 2]) |> Nx.reshape({query_size, num_trees * leaf_size}) - find_neighbors(query, forest.data, candidate_indices, num_neighbors: k) + Utils.find_neighbors(query, forest.data, candidate_indices, num_neighbors: k) end defnp compute_start_indices(forest, query) do @@ -400,64 +401,4 @@ defmodule Scholar.Neighbors.RandomProjectionForest do defnp left_child(nodes), do: 2 * nodes + 1 defnp right_child(nodes), do: 2 * nodes + 2 - - defnp find_neighbors(query, data, candidate_indices, opts) do - k = opts[:num_neighbors] - {size, length} = Nx.shape(candidate_indices) - - distances = - query - |> Nx.new_axis(1) - |> Nx.subtract(Nx.take(data, candidate_indices)) - |> Nx.pow(2) - |> Nx.sum(axes: [2]) - - distances = - if length > 1 do - 
sorted_indices = Nx.argsort(candidate_indices, axis: 1, stable: true) - inverse = inverse_permutation(sorted_indices) - sorted = Nx.take_along_axis(candidate_indices, sorted_indices, axis: 1) - - duplicate_mask = - Nx.concatenate( - [ - Nx.broadcast(0, {size, 1}), - Nx.equal(sorted[[.., 0..-2//1]], sorted[[.., 1..-1//1]]) - ], - axis: 1 - ) - |> Nx.take_along_axis(inverse, axis: 1) - - Nx.select(duplicate_mask, :infinity, distances) - else - distances - end - - indices = Nx.argsort(distances, axis: 1) |> Nx.slice_along_axis(0, k, axis: 1) - - neighbor_indices = - Nx.take( - Nx.vectorize(candidate_indices, :samples), - Nx.vectorize(indices, :samples) - ) - |> Nx.devectorize() - |> Nx.rename(nil) - - neighbor_distances = Nx.take_along_axis(distances, indices, axis: 1) - - {neighbor_indices, neighbor_distances} - end - - defnp inverse_permutation(indices) do - {size, length} = Nx.shape(indices) - target = Nx.broadcast(Nx.u32(0), {size, length}) - samples = Nx.iota({size, length, 1}, axis: 0) - - indices = - Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2) - |> Nx.reshape({size * length, 2}) - - updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length}) - Nx.indexed_add(target, indices, updates) - end end diff --git a/lib/scholar/neighbors/utils.ex b/lib/scholar/neighbors/utils.ex new file mode 100644 index 00000000..f202348b --- /dev/null +++ b/lib/scholar/neighbors/utils.ex @@ -0,0 +1,65 @@ +defmodule Scholar.Neighbors.Utils do + @moduledoc false + import Nx.Defn + require Nx + + defn find_neighbors(query, data, candidate_indices, opts) do + k = opts[:num_neighbors] + {size, length} = Nx.shape(candidate_indices) + + distances = + query + |> Nx.new_axis(1) + |> Nx.subtract(Nx.take(data, candidate_indices)) + |> Nx.pow(2) + |> Nx.sum(axes: [2]) + + distances = + if length > 1 do + sorted_indices = Nx.argsort(candidate_indices, axis: 1, stable: true) + inverse = inverse_permutation(sorted_indices) + sorted = 
Nx.take_along_axis(candidate_indices, sorted_indices, axis: 1) + + duplicate_mask = + Nx.concatenate( + [ + Nx.broadcast(0, {size, 1}), + Nx.equal(sorted[[.., 0..-2//1]], sorted[[.., 1..-1//1]]) + ], + axis: 1 + ) + |> Nx.take_along_axis(inverse, axis: 1) + + Nx.select(duplicate_mask, :infinity, distances) + else + distances + end + + indices = Nx.argsort(distances, axis: 1) |> Nx.slice_along_axis(0, k, axis: 1) + + neighbor_indices = + Nx.take( + Nx.vectorize(candidate_indices, :samples), + Nx.vectorize(indices, :samples) + ) + |> Nx.devectorize() + |> Nx.rename(nil) + + neighbor_distances = Nx.take_along_axis(distances, indices, axis: 1) + + {neighbor_indices, neighbor_distances} + end + + defnp inverse_permutation(indices) do + {size, length} = Nx.shape(indices) + target = Nx.broadcast(Nx.u32(0), {size, length}) + samples = Nx.iota({size, length, 1}, axis: 0) + + indices = + Nx.concatenate([samples, Nx.new_axis(indices, 2)], axis: 2) + |> Nx.reshape({size * length, 2}) + + updates = Nx.iota({size, length}, axis: 1) |> Nx.reshape({size * length}) + Nx.indexed_add(target, indices, updates) + end +end