Skip to content

Commit

Permalink
Rename predict_proba to predict_probability, fix a bug inside of it
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed May 14, 2024
1 parent 2580ebd commit 1837382
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
18 changes: 9 additions & 9 deletions lib/scholar/neighbors/knn_classifier.ex
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ defmodule Scholar.Neighbors.KNNClassifier do
"""
defn predict(model, x) do
{neighbors, distances} = compute_knn(model.algorithm, x)
labels_pred = Nx.take(model.labels, neighbors)
neighbor_labels = Nx.take(model.labels, neighbors)

case model.weights do
:uniform -> Nx.mode(labels_pred, axis: 1)
:distance -> weighted_mode(labels_pred, check_weights(distances))
:uniform -> Nx.mode(neighbor_labels, axis: 1)
:distance -> weighted_mode(neighbor_labels, check_weights(distances))
end
end

Expand All @@ -170,20 +170,20 @@ defmodule Scholar.Neighbors.KNNClassifier do
iex> y_train = Nx.tensor([0, 0, 0, 1, 1])
iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2)
iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> Scholar.Neighbors.KNNClassifier.predict_proba(model, x)
iex> Scholar.Neighbors.KNNClassifier.predict_probability(model, x)
Nx.tensor(
[
[1.0, 0.0],
[1.0, 0.0],
[1.0, 0.0]
[0.3333333432674408, 0.6666666865348816]
]
)
"""
defn predict_proba(model, x) do
defn predict_probability(model, x) do
num_samples = Nx.axis_size(x, 0)
{neighbors, distances} = compute_knn(model.algorithm, x)
labels_pred = Nx.take(model.labels, neighbors)
type = Nx.Type.merge(to_float_type(x), {:f, 32})
{neighbors, distances} = compute_knn(model.algorithm, x)
neighbor_labels = Nx.take(model.labels, neighbors)
proba = Nx.broadcast(Nx.tensor(0.0, type: type), {num_samples, model.num_classes})

weights =
Expand All @@ -194,7 +194,7 @@ defmodule Scholar.Neighbors.KNNClassifier do

indices =
Nx.stack(
[Nx.iota(Nx.shape(labels_pred), axis: 0), Nx.take(model.labels, labels_pred)],
[Nx.iota(Nx.shape(neighbor_labels), axis: 0), neighbor_labels],
axis: 2
)
|> Nx.flatten(axes: [0, 1])
Expand Down
18 changes: 9 additions & 9 deletions test/scholar/neighbors/knn_classifier_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,10 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
end
end

describe "predict_proba" do
test "predict_proba with default values" do
describe "predict_probability" do
test "predict_probability with default values" do
model = KNNClassifier.fit(x_train(), y_train(), num_classes: 2, num_neighbors: 3)
predictions = KNNClassifier.predict_proba(model, x())
predictions = KNNClassifier.predict_probability(model, x())

assert_all_close(
predictions,
Expand All @@ -150,15 +150,15 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
)
end

test "predict_proba with weights set to :distance" do
test "predict_probability with weights set to :distance" do
model =
KNNClassifier.fit(x_train(), y_train(),
num_neighbors: 3,
num_classes: 2,
weights: :distance
)

predictions = KNNClassifier.predict_proba(model, x())
predictions = KNNClassifier.predict_probability(model, x())

assert_all_close(
predictions,
Expand All @@ -171,7 +171,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
)
end

test "predict_proba with weights set to :distance and with specific metric" do
test "predict_probability with weights set to :distance and with specific metric" do
model =
KNNClassifier.fit(x_train(), y_train(),
num_classes: 2,
Expand All @@ -180,7 +180,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
metric: {:minkowski, 1.5}
)

predictions = KNNClassifier.predict_proba(model, x())
predictions = KNNClassifier.predict_probability(model, x())

assert_all_close(
predictions,
Expand All @@ -193,7 +193,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
)
end

test "predict_proba with weights set to :distance and with x that contains sample with zero-distance" do
test "predict_probability with weights set to :distance and with x that contains sample with zero-distance" do
x = Nx.tensor([[3, 6, 7, 5], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]])

model =
Expand All @@ -203,7 +203,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do
weights: :distance
)

predictions = KNNClassifier.predict_proba(model, x)
predictions = KNNClassifier.predict_probability(model, x)

assert_all_close(
predictions,
Expand Down

0 comments on commit 1837382

Please sign in to comment.