Skip to content

Commit

Permalink
y must be of rank 2
Browse files Browse the repository at this point in the history
  • Loading branch information
Krsto Proroković committed May 16, 2024
1 parent 9e8f35a commit 27e4f77
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
32 changes: 14 additions & 18 deletions lib/scholar/neighbors/knn_regressor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -55,43 +55,43 @@ defmodule Scholar.Neighbors.KNNRegressor do
## Examples
iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y = Nx.tensor([1, 2, 3, 4, 5])
iex> y = Nx.tensor([[1], [2], [3], [4], [5]])
iex> model = Scholar.Neighbors.KNNRegressor.fit(x, y, num_neighbors: 3)
iex> model.algorithm
Scholar.Neighbors.BruteKNN.fit(x, num_neighbors: 3)
iex> model.labels
Nx.tensor([1, 2, 3, 4, 5])
Nx.tensor([[1], [2], [3], [4], [5]])
iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y = Nx.tensor([1, 2, 3, 4, 5])
iex> y = Nx.tensor([[1], [2], [3], [4], [5]])
iex> model = Scholar.Neighbors.KNNRegressor.fit(x, y, algorithm: :kd_tree, num_neighbors: 3, metric: {:minkowski, 1})
iex> model.algorithm
Scholar.Neighbors.KDTree.fit(x, num_neighbors: 3, metric: {:minkowski, 1})
iex> model.labels
Nx.tensor([1, 2, 3, 4, 5])
Nx.tensor([[1], [2], [3], [4], [5]])
iex> x = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y = Nx.tensor([1, 2, 3, 4, 5])
iex> y = Nx.tensor([[1], [2], [3], [4], [5]])
iex> key = Nx.Random.key(12)
iex> model = Scholar.Neighbors.KNNRegressor.fit(x, y, algorithm: :random_projection_forest, num_neighbors: 2, num_trees: 4, key: key)
iex> model.algorithm
Scholar.Neighbors.RandomProjectionForest.fit(x, num_neighbors: 2, num_trees: 4, key: key)
iex> model.labels
Nx.tensor([1, 2, 3, 4, 5])
Nx.tensor([[1], [2], [3], [4], [5]])
"""
deftransform fit(x, y, opts) do
if Nx.rank(x) != 2 do
raise ArgumentError,
"""
expected x to have shape {num_samples, num_features}, \
expected x to have shape {num_samples, num_features_in}, \
got tensor with shape: #{inspect(Nx.shape(x))}
"""
end

if Nx.rank(y) not in [1, 2] do
if Nx.rank(y) != 2 do
raise ArgumentError,
"""
expected y to have rank 1 or 2, \
expected y to have shape {num_samples, num_features_out}, \
got tensor with shape: #{inspect(Nx.shape(y))}
"""
end
Expand Down Expand Up @@ -137,11 +137,11 @@ defmodule Scholar.Neighbors.KNNRegressor do
## Examples
iex> x_train = Nx.tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])
iex> y_train = Nx.tensor([1, 2, 3, 4, 5])
iex> y_train = Nx.tensor([[1], [2], [3], [4], [5]])
iex> model = Scholar.Neighbors.KNNRegressor.fit(x_train, y_train, num_neighbors: 3)
iex> x = Nx.tensor([[1, 3], [4, 2], [3, 6]])
iex> Scholar.Neighbors.KNNRegressor.predict(model, x)
Nx.tensor([2.0, 2.0, 4.0])
Nx.tensor([[2.0], [2.0], [4.0]])
"""
defn predict(model, x) do
{neighbors, distances} = compute_knn(model.algorithm, x)
Expand All @@ -152,14 +152,10 @@ defmodule Scholar.Neighbors.KNNRegressor do
Nx.mean(neighbor_labels, axes: [1])

:distance ->
weights = Scholar.Neighbors.Utils.check_weights(distances)

weights =
if Nx.rank(model.labels) == 2 do
weights |> Nx.new_axis(2) |> Nx.broadcast(neighbor_labels)
else
weights
end
Scholar.Neighbors.Utils.check_weights(distances)
|> Nx.new_axis(2)
|> Nx.broadcast(neighbor_labels)

Nx.weighted_mean(neighbor_labels, weights, axes: [1])
end
Expand Down
8 changes: 4 additions & 4 deletions test/scholar/neighbors/knn_regressor_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ defmodule Scholar.Neighbors.KNNRegressorTest do
end

defp y_train do
Nx.tensor([0, 1, 1, 1, 1, 1, 1, 1, 0, 0])
Nx.tensor([[0], [1], [1], [1], [1], [1], [1], [1], [0], [0]])
end

defp x do
Expand Down Expand Up @@ -70,13 +70,13 @@ defmodule Scholar.Neighbors.KNNRegressorTest do
test "predict with default parameters" do
model = KNNRegressor.fit(x_train(), y_train(), num_neighbors: 3)
y_pred = KNNRegressor.predict(model, x())
assert_all_close(y_pred, Nx.tensor([0.66666667, 0.66666667, 0.33333333, 1.0]))
assert_all_close(y_pred, Nx.tensor([[0.66666667], [0.66666667], [0.33333333], [1.0]]))
end

test "predict with weights set to :distance" do
model = KNNRegressor.fit(x_train(), y_train(), num_neighbors: 3, weights: :distance)
y_pred = KNNRegressor.predict(model, x())
assert_all_close(y_pred, Nx.tensor([0.59648849, 0.68282796, 0.2716506, 1.0]))
assert_all_close(y_pred, Nx.tensor([[0.59648849], [0.68282796], [0.2716506], [1.0]]))
end

test "predict with cosine metric and weights set to :distance" do
Expand All @@ -88,7 +88,7 @@ defmodule Scholar.Neighbors.KNNRegressorTest do
)

y_pred = KNNRegressor.predict(model, x())
assert_all_close(y_pred, Nx.tensor([0.5736568, 0.427104, 0.33561941, 1.0]))
assert_all_close(y_pred, Nx.tensor([[0.5736568], [0.427104], [0.33561941], [1.0]]))
end

test "predict with 2D labels" do
Expand Down

0 comments on commit 27e4f77

Please sign in to comment.