Skip to content

Commit

Permalink
Knn via kdtree (elixir-nx#211)
Browse files Browse the repository at this point in the history
* KNN via KDTree

* Remove redundant chunk of code

* Update lib/scholar/neighbors/kd_tree.ex

* Remove redundant variable, change spelling, change handling type

* Format

* Add amendments

* Add benchmark, tests, documentation, and metric option

* Format

* Change API

* Update knn benchmark
  • Loading branch information
msluszniak authored Nov 28, 2023
1 parent 97351d3 commit 130dd3c
Show file tree
Hide file tree
Showing 3 changed files with 372 additions and 54 deletions.
27 changes: 27 additions & 0 deletions benchmarks/knn.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# mix run benchmarks/knn.exs
Nx.global_default_backend(EXLA.Backend)
Nx.Defn.global_default_options(compiler: EXLA)

key = Nx.Random.key(System.os_time())

inputs_knn = %{
"100x10" => elem(Nx.Random.uniform(key, 0, 100, shape: {100, 10}), 0),
"1000x10" => elem(Nx.Random.uniform(key, 0, 1000, shape: {1000, 10}), 0),
"10000x10" => elem(Nx.Random.uniform(key, 0, 10000, shape: {10000, 10}), 0)
}

Benchee.run(
%{
"kdtree" => fn x ->
kdtree = Scholar.Neighbors.KDTree.fit_bounded(x, Nx.axis_size(x, 0))
Scholar.Neighbors.KDTree.predict(kdtree, x, k: 4)
end,
"brute force knn" => fn x ->
model = Scholar.Neighbors.KNearestNeighbors.fit(x, Nx.broadcast(1, {Nx.axis_size(x, 0)}), num_classes: 2, num_neighbors: 4)
Scholar.Neighbors.KNearestNeighbors.k_neighbors(model, x)
end
},
time: 10,
memory_time: 2,
inputs: inputs_knn
)
Loading

0 comments on commit 130dd3c

Please sign in to comment.