Skip to content

Commit

Permalink
Add predict_proba/2
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed May 14, 2024
1 parent ea158f1 commit 19f1f75
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 20 deletions.
15 changes: 12 additions & 3 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,20 @@ defmodule Scholar.Neighbors.KDTree do
"""
deftransform predict(tree, data) do
if Nx.rank(data) != 2 do
raise ArgumentError, "Input data must be a 2D tensor"
raise ArgumentError,
"""
expected query tensor to have shape {num_queries, num_features}, \
got tensor with shape: #{inspect(Nx.shape(data))}
"""
end

if Nx.axis_size(data, -1) != Nx.axis_size(tree.data, -1) do
raise ArgumentError, "Input data must have the same number of features as the training data"
if Nx.axis_size(tree.data, 1) != Nx.axis_size(data, 1) do
raise ArgumentError,
"""
expected query tensor to have same number of features as tensor used to fit the tree, \
got #{inspect(Nx.axis_size(data, 1))} \
and #{inspect(Nx.axis_size(tree.data, 1))}
"""
end

predict_n(tree, data)
Expand Down
30 changes: 24 additions & 6 deletions lib/scholar/neighbors/knn_classifier.ex
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,15 @@ defmodule Scholar.Neighbors.KNNClassifier do
end

@doc """
Makes predictions using a k-NN classifier model.
Predicts classes using a k-NN classifier model.
## Examples
iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
iex> x_test = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> Scholar.Neighbors.KNNClassifier.predict(model, x_test)
iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> Scholar.Neighbors.KNNClassifier.predict(model, x)
Nx.tensor([0, 0, 1])
"""
defn predict(model, x) do
Expand All @@ -161,6 +161,24 @@ defmodule Scholar.Neighbors.KNNClassifier do
end
end

@doc """
Predicts class probabilities using a k-NN classifier model.
## Examples
iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> Scholar.Neighbors.KNNClassifier.predict_proba(model, x)
Nx.tensor(
[
[1.0, 0.0],
[1.0, 0.0],
[1.0, 0.0]
]
)
"""
defn predict_proba(model, x) do
num_samples = Nx.axis_size(x, 0)
{neighbors, distances} = compute_knn(model.algorithm, x)
Expand All @@ -170,21 +188,21 @@ defmodule Scholar.Neighbors.KNNClassifier do

weights =
case model.weights do
:distance -> check_weights(distances)
:uniform -> Nx.broadcast(1.0, neighbors)
:distance -> check_weights(distances)
end

indices =
Nx.stack(
[Nx.iota(Nx.shape(labels_pred), axis: 0), Nx.take(model.labels, labels_pred)],
axis: -1 # TODO: Replace -1 here
axis: 2
)
|> Nx.flatten(axes: [0, 1])

proba = Nx.indexed_add(proba, indices, Nx.flatten(weights))
normalizer = Nx.sum(proba, axes: [1])
normalizer = Nx.select(normalizer == 0, 1, normalizer)
proba / Nx.new_axis(normalizer, -1) # TODO: Replace -1 here
proba / Nx.new_axis(normalizer, 1)
end

deftransformp compute_knn(algorithm, x) do
Expand Down
6 changes: 3 additions & 3 deletions lib/scholar/neighbors/random_projection_forest.ex
Original file line number Diff line number Diff line change
Expand Up @@ -326,9 +326,9 @@ defmodule Scholar.Neighbors.RandomProjectionForest do
if Nx.axis_size(forest.data, 1) != Nx.axis_size(query, 1) do
raise ArgumentError,
"""
expected query tensor to have the same dimension as tensor used to grow the forest, \
got #{inspect(Nx.axis_size(forest.data, 1))} \
and #{inspect(Nx.axis_size(query, 1))}
expected query tensor to have same number of features as tensor used to grow the forest, \
got #{inspect(Nx.axis_size(query, 1))} \
and #{inspect(Nx.axis_size(forest.data, 1))}
"""
end

Expand Down
99 changes: 91 additions & 8 deletions test/scholar/neighbors/knn_classifier_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
model =
KNNClassifier.fit(x_train(), y_train(),
algorithm: :kd_tree,
num_neighbors: 3,
num_classes: 2
num_classes: 2,
num_neighbors: 3
)

assert model.algorithm == Scholar.Neighbors.KDTree.fit(x_train(), num_neighbors: 3)
Expand All @@ -56,8 +56,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
model =
KNNClassifier.fit(x_train(), y_train(),
algorithm: :random_projection_forest,
num_neighbors: 3,
num_classes: 2,
num_neighbors: 3,
num_trees: 4,
key: key
)
Expand Down Expand Up @@ -86,8 +86,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
model =
KNNClassifier.fit(x_train(), y_train(),
algorithm: :kd_tree,
num_neighbors: 3,
num_classes: 2
num_classes: 2,
num_neighbors: 3
)

labels_pred = KNNClassifier.predict(model, x())
Expand All @@ -97,8 +97,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
test "predict with weights set to :distance" do
model =
KNNClassifier.fit(x_train(), y_train(),
num_neighbors: 3,
num_classes: 2,
num_neighbors: 3,
weights: :distance
)

Expand All @@ -109,8 +109,8 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
test "predict with specific metric and weights set to :distance" do
model =
KNNClassifier.fit(x_train(), y_train(),
num_neighbors: 3,
num_classes: 2,
num_neighbors: 3,
metric: {:minkowski, 1.5},
weights: :distance
)
Expand All @@ -124,13 +124,96 @@ defmodule Scholar.Neighbors.KNNClassifierTest do

model =
KNNClassifier.fit(x_train(), y_train(),
num_neighbors: 3,
num_classes: 2,
num_neighbors: 3,
weights: :distance
)

labels_pred = KNNClassifier.predict(model, x)
assert labels_pred == Nx.tensor([0, 1, 0, 1])
end
end

describe "predict_proba" do
test "predict_proba with default values" do
model = KNNClassifier.fit(x_train(), y_train(), num_classes: 2, num_neighbors: 3)
predictions = KNNClassifier.predict_proba(model, x())

assert_all_close(
predictions,
Nx.tensor([
[0.33333333, 0.66666667],
[0.33333333, 0.66666667],
[0.66666667, 0.33333333],
[0.0, 1.0]
])
)
end

test "predict_proba with weights set to :distance" do
model =
KNNClassifier.fit(x_train(), y_train(),
num_neighbors: 3,
num_classes: 2,
weights: :distance
)

predictions = KNNClassifier.predict_proba(model, x())

assert_all_close(
predictions,
Nx.tensor([
[0.40351151, 0.59648849],
[0.31717204, 0.68282796],
[0.7283494, 0.2716506],
[0.0, 1.0]
])
)
end

test "predict_proba with weights set to :distance and with specific metric" do
model =
KNNClassifier.fit(x_train(), y_train(),
num_classes: 2,
num_neighbors: 3,
weights: :distance,
metric: {:minkowski, 1.5}
)

predictions = KNNClassifier.predict_proba(model, x())

assert_all_close(
predictions,
Nx.tensor([
[0.40381038, 0.59618962],
[0.31457406, 0.68542594],
[0.72993802, 0.27006198],
[0.0, 1.0]
])
)
end

test "predict_proba with weights set to :distance and with x that contains sample with zero-distance" do
x = Nx.tensor([[3, 6, 7, 5], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]])

model =
KNNClassifier.fit(x_train(), y_train(),
num_classes: 2,
num_neighbors: 3,
weights: :distance
)

predictions = KNNClassifier.predict_proba(model, x)

assert_all_close(
predictions,
Nx.tensor([
[1.0, 0.0],
[0.31717204, 0.68282796],
[0.7283494, 0.2716506],
[0.0, 1.0]
])
)
end
end
end

0 comments on commit 19f1f75

Please sign in to comment.