diff --git a/benchmarks/knn.exs b/benchmarks/knn.exs new file mode 100644 index 00000000..a0bc3d49 --- /dev/null +++ b/benchmarks/knn.exs @@ -0,0 +1,27 @@ +# mix run benchmarks/knn.exs +Nx.global_default_backend(EXLA.Backend) +Nx.Defn.global_default_options(compiler: EXLA) + +key = Nx.Random.key(System.os_time()) + +inputs_knn = %{ + "100x10" => elem(Nx.Random.uniform(key, 0, 100, shape: {100, 10}), 0), + "1000x10" => elem(Nx.Random.uniform(key, 0, 1000, shape: {1000, 10}), 0), + "10000x10" => elem(Nx.Random.uniform(key, 0, 10000, shape: {10000, 10}), 0) +} + +Benchee.run( + %{ + "kdtree" => fn x -> + kdtree = Scholar.Neighbors.KDTree.fit_bounded(x, Nx.axis_size(x, 0)) + Scholar.Neighbors.KDTree.predict(kdtree, x, k: 4) + end, + "brute force knn" => fn x -> + model = Scholar.Neighbors.KNearestNeighbors.fit(x, Nx.broadcast(1, {Nx.axis_size(x, 0)}), num_classes: 2, num_neighbors: 4) + Scholar.Neighbors.KNearestNeighbors.k_neighbors(model, x) + end + }, + time: 10, + memory_time: 2, + inputs: inputs_knn +) diff --git a/lib/scholar/neighbors/kd_tree.ex b/lib/scholar/neighbors/kd_tree.ex index c80077d8..d95eee30 100644 --- a/lib/scholar/neighbors/kd_tree.ex +++ b/lib/scholar/neighbors/kd_tree.ex @@ -9,13 +9,13 @@ defmodule Scholar.Neighbors.KDTree do Two construction modes are available: - * `bounded/2` - the tensor has min and max values with an amplitude given by `max - min`. + * `fit_bounded/2` - the tensor has min and max values with an amplitude given by `max - min`. It is also guaranteed that the `amplitude * levels(tensor) + 1` does not overflow the tensor. See `amplitude/1` to verify if this holds. This implementation happens fully within `defn`. - This version is orders of magnitude faster than the `unbounded/2` one. + This version is orders of magnitude faster than the `fit_unbounded/2` one. - * `unbounded/2` - there are no known bounds (min and max values) to the tensor. + * `fit_unbounded/2` - there are no known bounds (min and max values) to the tensor. 
This implementation is recursive and goes in and out of the `defn`, therefore it cannot be called inside `defn`. @@ -28,10 +28,33 @@ defmodule Scholar.Neighbors.KDTree do """ import Nx.Defn - - @derive {Nx.Container, keep: [:levels], containers: [:indexes, :data]} - @enforce_keys [:levels, :indexes, :data] - defstruct [:levels, :indexes, :data] + alias Scholar.Metrics.Distance + + @derive {Nx.Container, keep: [:levels], containers: [:indices, :data]} + @enforce_keys [:levels, :indices, :data] + defstruct [:levels, :indices, :data] + + opts = [ + k: [ + type: :pos_integer, + default: 3, + doc: "The number of neighbors to use by default for `k_neighbors` queries" + ], + metric: [ + type: {:custom, Scholar.Options, :metric, []}, + default: {:minkowski, 2}, + doc: ~S""" + Name of the metric. Possible values: + + * `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`) + we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric. + + * `:cosine` - Cosine metric. + """ + ] + ] + + @predict_schema NimbleOptions.new!(opts) @doc """ Builds a KDTree without known min-max bounds. 
@@ -46,19 +69,19 @@ defmodule Scholar.Neighbors.KDTree do ## Examples - iex> Scholar.Neighbors.KDTree.unbounded(Nx.iota({5, 2}), compiler: EXLA) + iex> Scholar.Neighbors.KDTree.fit_unbounded(Nx.iota({5, 2}), compiler: EXLA) %Scholar.Neighbors.KDTree{ data: Nx.iota({5, 2}), levels: 3, - indexes: Nx.u32([3, 1, 4, 0, 2]) + indices: Nx.u32([3, 1, 4, 0, 2]) } """ - def unbounded(tensor, opts \\ []) do + def fit_unbounded(tensor, opts \\ []) do levels = levels(tensor) {size, _dims} = Nx.shape(tensor) - indexes = + indices = if size > 2 do subtree_size = unbounded_subtree_size(1, levels, size) {left, mid, right} = Nx.Defn.jit_apply(&root_slice(&1, subtree_size), [tensor], opts) @@ -70,7 +93,7 @@ defmodule Scholar.Neighbors.KDTree do Nx.argsort(tensor[[.., 0]], direction: :desc, type: :u32) end - %__MODULE__{levels: levels, indexes: indexes, data: tensor} + %__MODULE__{levels: levels, indices: indices, data: tensor} end defp recur([{_i, %Nx.Tensor{shape: {1}} = leaf} | rest], next, acc, tensor, level, levels, opts) do @@ -85,13 +108,13 @@ defmodule Scholar.Neighbors.KDTree do recur(rest, next, acc, tensor, level, levels, opts) end - defp recur([{i, indexes} | rest], next, acc, tensor, level, levels, opts) do + defp recur([{i, indices} | rest], next, acc, tensor, level, levels, opts) do %Nx.Tensor{shape: {size, dims}} = tensor k = rem(level, dims) subtree_size = unbounded_subtree_size(left_child(i), levels, size) {left, mid, right} = - Nx.Defn.jit_apply(&recur_slice(&1, &2, &3, subtree_size), [tensor, indexes, k], opts) + Nx.Defn.jit_apply(&recur_slice(&1, &2, &3, subtree_size), [tensor, indices, k], opts) next = [{right_child(i), right}, {left_child(i), left} | next] acc = <> @@ -107,18 +130,18 @@ defmodule Scholar.Neighbors.KDTree do end defp root_slice(tensor, subtree_size) do - indexes = Nx.argsort(tensor[[.., 0]], type: :u32) + indices = Nx.argsort(tensor[[.., 0]], type: :u32) - {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], - Nx.slice(indexes, 
[subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} + {Nx.slice(indices, [0], [subtree_size]), indices[subtree_size], + Nx.slice(indices, [subtree_size + 1], [Nx.size(indices) - subtree_size - 1])} end - defp recur_slice(tensor, indexes, k, subtree_size) do - sorted = Nx.argsort(Nx.take(tensor, indexes)[[.., k]], type: :u32) - indexes = Nx.take(indexes, sorted) + defp recur_slice(tensor, indices, k, subtree_size) do + sorted = Nx.argsort(Nx.take(tensor, indices)[[.., k]], type: :u32) + indices = Nx.take(indices, sorted) - {Nx.slice(indexes, [0], [subtree_size]), indexes[subtree_size], - Nx.slice(indexes, [subtree_size + 1], [Nx.size(indexes) - subtree_size - 1])} + {Nx.slice(indices, [0], [subtree_size]), indices[subtree_size], + Nx.slice(indices, [subtree_size + 1], [Nx.size(indices) - subtree_size - 1])} end defp unbounded_subtree_size(i, levels, size) do @@ -145,18 +168,18 @@ defmodule Scholar.Neighbors.KDTree do ## Examples - iex> Scholar.Neighbors.KDTree.bounded(Nx.iota({5, 2}), 10) + iex> Scholar.Neighbors.KDTree.fit_bounded(Nx.iota({5, 2}), 10) %Scholar.Neighbors.KDTree{ data: Nx.iota({5, 2}), levels: 3, - indexes: Nx.u32([3, 1, 4, 0, 2]) + indices: Nx.u32([3, 1, 4, 0, 2]) } """ - deftransform bounded(tensor, amplitude) do - %__MODULE__{levels: levels(tensor), indexes: bounded_n(tensor, amplitude), data: tensor} + deftransform fit_bounded(tensor, amplitude) do + %__MODULE__{levels: levels(tensor), indices: fit_bounded_n(tensor, amplitude), data: tensor} end - defnp bounded_n(tensor, amplitude) do + defnp fit_bounded_n(tensor, amplitude) do levels = levels(tensor) {size, dims} = Nx.shape(tensor) band = amplitude + 1 @@ -165,8 +188,8 @@ defmodule Scholar.Neighbors.KDTree do {level, tags, _tensor, _band} = while {level = Nx.u32(0), tags, tensor, band}, level < levels - 1 do k = rem(level, dims) - indexes = Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) - tags = update_tags(tags, indexes, level, levels, size) + indices = Nx.argsort(tensor[[.., k]] 
+ band * tags, type: :u32) + tags = update_tags(tags, indices, level, levels, size) {level + 1, tags, tensor, band} end @@ -174,8 +197,8 @@ defmodule Scholar.Neighbors.KDTree do Nx.argsort(tensor[[.., k]] + band * tags, type: :u32) end - defnp update_tags(tags, indexes, level, levels, size) do - pos = Nx.argsort(indexes, type: :u32) + defnp update_tags(tags, indices, level, levels, size) do + pos = Nx.argsort(indices, type: :u32) pivot = bounded_segment_begin(tags, levels, size) + @@ -224,8 +247,8 @@ defmodule Scholar.Neighbors.KDTree do @doc """ Returns the amplitude of a bounded tensor. - If -1 is returned, it means the tensor cannot use the `bounded` algorithm - to generate a KDTree and `unbounded/2` must be used instead. + If -1 is returned, it means the tensor cannot use the `fit_bounded` algorithm + to generate a KDTree and `fit_unbounded/2` must be used instead. This cannot be invoked inside a `defn`. @@ -299,6 +322,10 @@ defmodule Scholar.Neighbors.KDTree do > """ + deftransform parent(0) do + -1 + end + deftransform parent(i) when is_integer(i), do: div(i - 1, 2) deftransform parent(%Nx.Tensor{} = t), do: Nx.quotient(Nx.subtract(t, 1), 2) @@ -347,4 +374,225 @@ defmodule Scholar.Neighbors.KDTree do """ deftransform right_child(i) when is_integer(i), do: 2 * i + 2 deftransform right_child(%Nx.Tensor{} = t), do: Nx.add(Nx.multiply(2, t), 2) + + @doc """ + Predict the K nearest neighbors of `x_predict` in KDTree. 
+ + ## Examples + + iex> x = Nx.iota({10, 2}) + iex> x_predict = Nx.tensor([[2, 5], [1, 9], [6, 4]]) + iex> kdtree = Scholar.Neighbors.KDTree.fit_bounded(x, 20) + iex> Scholar.Neighbors.KDTree.predict(kdtree, x_predict, k: 3) + #Nx.Tensor< + s64[3][3] + [ + [2, 1, 0], + [2, 3, 1], + [2, 3, 1] + ] + > + iex> Scholar.Neighbors.KDTree.predict(kdtree, x_predict, k: 3, metric: {:minkowski, 1}) + #Nx.Tensor< + s64[3][3] + [ + [2, 1, 0], + [4, 3, 1], + [2, 3, 1] + ] + > + """ + deftransform predict(tree, data, opts \\ []) do + predict_n(tree, data, NimbleOptions.validate!(opts, @predict_schema)) + end + + defnp predict_n(tree, data, opts) do + k = opts[:k] + num_samples = Nx.axis_size(data, 0) + knn = Nx.broadcast(Nx.s64(0), {num_samples, k}) + + {knn, _} = + while {knn, {tree, data, i = Nx.s64(0)}}, i < num_samples do + curr_point = data[[i]] + k_neighbors = query_one_point(curr_point, tree, opts) + knn = Nx.put_slice(knn, [i, 0], Nx.new_axis(k_neighbors, 0)) + {knn, {tree, data, i + 1}} + end + + knn + end + + defnp sort_by_distances(distances, point_indices) do + indices = Nx.argsort(distances) + {Nx.take(distances, indices), Nx.take(point_indices, indices)} + end + + defnp compute_distance(x1, x2, opts) do + case opts[:metric] do + {:minkowski, 2} -> Distance.squared_euclidean(x1, x2) + {:minkowski, p} -> Distance.minkowski(x1, x2, p: p) + :cosine -> Distance.cosine(x1, x2) + end + end + + defnp update_knn(nearest_neighbors, distances, data, indices, curr_node, point, k, opts) do + curr_dist = compute_distance(data[[indices[curr_node]]], point, opts) + + if curr_dist < distances[[-1]] do + nearest_neighbors = + Nx.indexed_put(nearest_neighbors, Nx.new_axis(k - 1, 0), indices[curr_node]) + + distances = Nx.indexed_put(distances, Nx.new_axis(k - 1, 0), curr_dist) + sort_by_distances(distances, nearest_neighbors) + else + {distances, nearest_neighbors} + end + end + + defnp update_visited(node, visited, distances, nearest_neighbors, data, indices, point, k, opts) do + if 
visited[indices[node]] do + {visited, {distances, nearest_neighbors}} + else + visited = Nx.indexed_put(visited, Nx.new_axis(indices[node], 0), Nx.u8(1)) + + {distances, nearest_neighbors} = + update_knn(nearest_neighbors, distances, data, indices, node, point, k, opts) + + {visited, {distances, nearest_neighbors}} + end + end + + defnp query_one_point(point, tree, opts) do + k = opts[:k] + node = Nx.as_type(root(), :s64) + {size, dims} = Nx.shape(tree.data) + nearest_neighbors = Nx.broadcast(Nx.s64(0), {k}) + distances = Nx.broadcast(Nx.Constants.infinity(), {k}) + visited = Nx.broadcast(Nx.u8(0), {size}) + + indices = tree.indices |> Nx.as_type(:s64) + data = tree.data + + down = 0 + up = 1 + mode = down + + {nearest_neighbors, _} = + while {nearest_neighbors, + {node, data, indices, point, distances, visited, i = Nx.s64(0), mode}}, + node != -1 and i >= 0 do + coord_indicator = rem(i, dims) + + {node, i, visited, nearest_neighbors, distances, mode} = + cond do + node >= size -> + {parent(node), i - 1, visited, nearest_neighbors, distances, up} + + mode == down and + point[[coord_indicator]] < data[[indices[node], coord_indicator]] -> + {left_child(node), i + 1, visited, nearest_neighbors, distances, down} + + mode == down and + point[[coord_indicator]] >= data[[indices[node], coord_indicator]] -> + {right_child(node), i + 1, visited, nearest_neighbors, distances, down} + + mode == up -> + cond do + visited[indices[node]] -> + {parent(node), i - 1, visited, nearest_neighbors, distances, up} + + (left_child(node) >= size and right_child(node) >= size) or + (left_child(node) < size and visited[indices[left_child(node)]] and + right_child(node) < size and + visited[indices[right_child(node)]]) or + (left_child(node) < size and visited[indices[left_child(node)]] and + right_child(node) >= size) -> + {visited, {distances, nearest_neighbors}} = + update_visited( + node, + visited, + distances, + nearest_neighbors, + data, + indices, + point, + k, + opts + ) + + 
{parent(node), i - 1, visited, nearest_neighbors, distances, up} + + left_child(node) < size and visited[indices[left_child(node)]] and + right_child(node) < size and + not visited[indices[right_child(node)]] -> + {visited, {distances, nearest_neighbors}} = + update_visited( + node, + visited, + distances, + nearest_neighbors, + data, + indices, + point, + k, + opts + ) + + if Nx.any( + compute_distance( + point[[coord_indicator]], + data[[indices[right_child(node)], coord_indicator]], + opts + ) < + distances + ) do + {right_child(node), i + 1, visited, nearest_neighbors, distances, down} + else + {parent(node), i - 1, visited, nearest_neighbors, distances, up} + end + + ((right_child(node) < size and visited[indices[right_child(node)]]) or + right_child(node) == size) and + not visited[indices[left_child(node)]] -> + {visited, {distances, nearest_neighbors}} = + update_visited( + node, + visited, + distances, + nearest_neighbors, + data, + indices, + point, + k, + opts + ) + + if Nx.any( + compute_distance( + point[[coord_indicator]], + data[[indices[left_child(node)], coord_indicator]], + opts + ) < + distances + ) do + {left_child(node), i + 1, visited, nearest_neighbors, distances, down} + else + {parent(node), i - 1, visited, nearest_neighbors, distances, up} + end + + # Should not be reachable + true -> + {node, i, visited, nearest_neighbors, distances, mode} + end + + # Should not be reachable + true -> + {node, i, visited, nearest_neighbors, distances, mode} + end + + {nearest_neighbors, {node, data, indices, point, distances, visited, i, mode}} + end + + nearest_neighbors + end end diff --git a/test/scholar/neighbors/kd_tree_test.exs b/test/scholar/neighbors/kd_tree_test.exs index 3df22fbe..1d256375 100644 --- a/test/scholar/neighbors/kd_tree_test.exs +++ b/test/scholar/neighbors/kd_tree_test.exs @@ -1,6 +1,7 @@ defmodule Scholar.Neighbors.KDTreeTest do use ExUnit.Case, async: true - doctest Scholar.Neighbors.KDTree + alias Scholar.Neighbors.KDTree + 
doctest KDTree defp example do Nx.tensor([ @@ -19,54 +20,96 @@ defmodule Scholar.Neighbors.KDTreeTest do describe "unbounded" do test "sample" do - assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbounded(example(), compiler: EXLA) + assert %KDTree{levels: 4, indices: indices} = + KDTree.fit_unbounded(example(), compiler: EXLA) - assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end test "float" do - assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.unbounded(example() |> Nx.as_type(:f32), + assert %KDTree{levels: 4, indices: indices} = + KDTree.fit_unbounded(example() |> Nx.as_type(:f32), compiler: EXLA ) - assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end test "corner cases" do - assert %Scholar.Neighbors.KDTree{levels: 1, indexes: indexes} = - Scholar.Neighbors.KDTree.unbounded(Nx.iota({1, 2}), compiler: EXLA) + assert %KDTree{levels: 1, indices: indices} = + KDTree.fit_unbounded(Nx.iota({1, 2}), compiler: EXLA) - assert indexes == Nx.u32([0]) + assert indices == Nx.u32([0]) - assert %Scholar.Neighbors.KDTree{levels: 2, indexes: indexes} = - Scholar.Neighbors.KDTree.unbounded(Nx.iota({2, 2}), compiler: EXLA) + assert %KDTree{levels: 2, indices: indices} = + KDTree.fit_unbounded(Nx.iota({2, 2}), compiler: EXLA) - assert indexes == Nx.u32([1, 0]) + assert indices == Nx.u32([1, 0]) end end describe "bounded" do test "iota" do - assert %Scholar.Neighbors.KDTree{levels: 3, indexes: indexes} = - Scholar.Neighbors.KDTree.bounded(Nx.iota({5, 2}), 10) + assert %KDTree{levels: 3, indices: indices} = + KDTree.fit_bounded(Nx.iota({5, 2}), 10) - assert indexes == Nx.u32([3, 1, 4, 0, 2]) + assert indices == Nx.u32([3, 1, 4, 0, 2]) end test "float" do - assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - 
Scholar.Neighbors.KDTree.bounded(example() |> Nx.as_type(:f32), 100) + assert %KDTree{levels: 4, indices: indices} = + KDTree.fit_bounded(example() |> Nx.as_type(:f32), 100) - assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] end test "sample" do - assert %Scholar.Neighbors.KDTree{levels: 4, indexes: indexes} = - Scholar.Neighbors.KDTree.bounded(example(), 100) + assert %KDTree{levels: 4, indices: indices} = + KDTree.fit_bounded(example(), 100) - assert Nx.to_flat_list(indexes) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + assert Nx.to_flat_list(indices) == [1, 5, 9, 3, 6, 2, 8, 0, 7, 4] + end + end + + defp x do + Nx.tensor([ + [3, 6, 7, 5], + [9, 8, 5, 4], + [4, 4, 4, 1], + [9, 4, 5, 6], + [6, 4, 5, 7], + [4, 5, 3, 3], + [4, 5, 7, 8], + [9, 4, 4, 5], + [8, 4, 3, 9], + [2, 8, 4, 4] + ]) + end + + defp x_pred do + Nx.tensor([[4, 3, 8, 4], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) + end + + describe "predict knn" do + test "all defaults" do + kdtree = KDTree.fit_bounded(x(), 10) + + assert KDTree.predict(kdtree, x_pred()) == + Nx.tensor([[0, 6, 4], [5, 2, 9], [0, 9, 2], [5, 2, 7]]) + end + + test "metric set to {:minkowski, 1.5}" do + kdtree = KDTree.fit_bounded(x(), 10) + + assert KDTree.predict(kdtree, x_pred(), metric: {:minkowski, 1.5}) == + Nx.tensor([[0, 6, 2], [5, 2, 9], [0, 9, 2], [5, 2, 7]]) + end + + test "k set to 4" do + kdtree = KDTree.fit_bounded(x(), 10) + + assert KDTree.predict(kdtree, x_pred(), k: 4) == + Nx.tensor([[0, 6, 4, 2], [5, 2, 9, 0], [0, 9, 2, 5], [5, 2, 7, 4]]) end end end