diff --git a/lib/scholar/neighbors/knn_classifier.ex b/lib/scholar/neighbors/knn_classifier.ex index 4ddea1ba..70991a6f 100644 --- a/lib/scholar/neighbors/knn_classifier.ex +++ b/lib/scholar/neighbors/knn_classifier.ex @@ -153,11 +153,11 @@ defmodule Scholar.Neighbors.KNNClassifier do """ defn predict(model, x) do {neighbors, distances} = compute_knn(model.algorithm, x) - labels_pred = Nx.take(model.labels, neighbors) + neighbor_labels = Nx.take(model.labels, neighbors) case model.weights do - :uniform -> Nx.mode(labels_pred, axis: 1) - :distance -> weighted_mode(labels_pred, check_weights(distances)) + :uniform -> Nx.mode(neighbor_labels, axis: 1) + :distance -> weighted_mode(neighbor_labels, check_weights(distances)) end end @@ -170,20 +170,20 @@ defmodule Scholar.Neighbors.KNNClassifier do iex> y_train = Nx.tensor([0, 0, 0, 1, 1]) iex> model = Scholar.Neighbors.KNNClassifier.fit(x_train, y_train, num_neighbors: 3, num_classes: 2) iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]]) - iex> Scholar.Neighbors.KNNClassifier.predict_proba(model, x) + iex> Scholar.Neighbors.KNNClassifier.predict_probability(model, x) Nx.tensor( [ [1.0, 0.0], [1.0, 0.0], - [1.0, 0.0] + [0.3333333432674408, 0.6666666865348816] ] ) """ - defn predict_proba(model, x) do + defn predict_probability(model, x) do num_samples = Nx.axis_size(x, 0) - {neighbors, distances} = compute_knn(model.algorithm, x) - labels_pred = Nx.take(model.labels, neighbors) type = Nx.Type.merge(to_float_type(x), {:f, 32}) + {neighbors, distances} = compute_knn(model.algorithm, x) + neighbor_labels = Nx.take(model.labels, neighbors) proba = Nx.broadcast(Nx.tensor(0.0, type: type), {num_samples, model.num_classes}) weights = @@ -194,7 +194,7 @@ defmodule Scholar.Neighbors.KNNClassifier do indices = Nx.stack( - [Nx.iota(Nx.shape(labels_pred), axis: 0), Nx.take(model.labels, labels_pred)], + [Nx.iota(Nx.shape(neighbor_labels), axis: 0), neighbor_labels], axis: 2 ) |> Nx.flatten(axes: [0, 1]) diff --git a/test/scholar/neighbors/knn_classifier_test.exs b/test/scholar/neighbors/knn_classifier_test.exs index 2c08311d..4df48d77 100644 --- a/test/scholar/neighbors/knn_classifier_test.exs +++ b/test/scholar/neighbors/knn_classifier_test.exs @@ -134,10 +134,10 @@ defmodule Scholar.Neighbors.KNNClassifierTest do end end - describe "predict_proba" do - test "predict_proba with default values" do + describe "predict_probability" do + test "predict_probability with default values" do model = KNNClassifier.fit(x_train(), y_train(), num_classes: 2, num_neighbors: 3) - predictions = KNNClassifier.predict_proba(model, x()) + predictions = KNNClassifier.predict_probability(model, x()) assert_all_close( predictions, @@ -150,7 +150,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do ) end - test "predict_proba with weights set to :distance" do + test "predict_probability with weights set to :distance" do model = KNNClassifier.fit(x_train(), y_train(), num_neighbors: 3, @@ -158,7 +158,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do weights: :distance ) - predictions = KNNClassifier.predict_proba(model, x()) + predictions = KNNClassifier.predict_probability(model, x()) assert_all_close( predictions, @@ -171,7 +171,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do ) end - test "predict_proba with weights set to :distance and with specific metric" do + test "predict_probability with weights set to :distance and with specific metric" do model = KNNClassifier.fit(x_train(), y_train(), num_classes: 2, @@ -180,7 +180,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do metric: {:minkowski, 1.5} ) - predictions = KNNClassifier.predict_proba(model, x()) + predictions = KNNClassifier.predict_probability(model, x()) assert_all_close( predictions, @@ -193,7 +193,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do ) end - test "predict_proba with weights set to :distance and with x that contains sample with zero-distance" do + test "predict_probability with weights set to :distance and with x that contains sample with zero-distance" do x = Nx.tensor([[3, 6, 7, 5], [1, 6, 1, 1], [3, 7, 9, 2], [5, 2, 1, 2]]) model = @@ -203,7 +203,7 @@ defmodule Scholar.Neighbors.KNNClassifierTest do weights: :distance ) - predictions = KNNClassifier.predict_proba(model, x) + predictions = KNNClassifier.predict_probability(model, x) assert_all_close( predictions,