Split RadiusNearestNeighbors module into RNNClassifier and RNNRegressor (#296)

norm4nn authored Sep 10, 2024
1 parent 2dca4aa commit 1765930
Showing 6 changed files with 467 additions and 230 deletions.
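
At the call-site level the split works out roughly as in the sketch below. The `RNNClassifier` lines follow the doctests in this diff; `RNNRegressor` is only named in the commit title and does not appear in this excerpt, so its exact options are an assumption here.

```elixir
x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
y = Nx.tensor([1, 0, 1, 1])
query = Nx.tensor([[1.9, 4.3], [1.1, 2.0]])

# Before: one module, the task chosen via the :task option.
old = Scholar.Neighbors.RadiusNearestNeighbors.fit(x, y, num_classes: 2, task: :classification)
Scholar.Neighbors.RadiusNearestNeighbors.predict(old, query)

# After: a dedicated classifier; :task is gone and :num_classes is now required.
clf = Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
Scholar.Neighbors.RNNClassifier.predict(clf, query)

# Regression presumably moves to RNNRegressor (named in the commit title, not shown
# in this excerpt); the call shape below is assumed to mirror the classifier's.
y_cont = Nx.tensor([1.0, 0.0, 1.0, 1.0])
reg = Scholar.Neighbors.RNNRegressor.fit(x, y_cont, radius: 2.0)
Scholar.Neighbors.RNNRegressor.predict(reg, query)
```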
4 changes: 2 additions & 2 deletions lib/scholar/cluster/dbscan.ex
@@ -98,14 +98,14 @@ defmodule Scholar.Cluster.DBSCAN do
y_dummy = Nx.broadcast(Nx.tensor(0), {num_samples})

neighbor_model =
Scholar.Neighbors.RadiusNearestNeighbors.fit(x, y_dummy,
Scholar.Neighbors.RNNClassifier.fit(x, y_dummy,
num_classes: 1,
radius: opts[:eps],
metric: opts[:metric]
)

{_dist, indices} =
Scholar.Neighbors.RadiusNearestNeighbors.radius_neighbors(neighbor_model, x)
Scholar.Neighbors.RNNClassifier.radius_neighbors(neighbor_model, x)

n_neighbors = Nx.sum(indices * weights, axes: [1])
core_samples = n_neighbors >= opts[:min_samples]
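
The classifier here is used purely as a neighbor-search structure: the dummy labels and `num_classes: 1` only satisfy `fit/3`, and DBSCAN consumes nothing but the `radius_neighbors/2` mask, whose weighted row sums give each point's neighbor count. A minimal sketch of that pattern (the data, radius, and min-samples threshold below are illustrative, not from the diff):

```elixir
x = Nx.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])
y_dummy = Nx.broadcast(Nx.tensor(0), {Nx.axis_size(x, 0)})

model = Scholar.Neighbors.RNNClassifier.fit(x, y_dummy, num_classes: 1, radius: 0.5)
{_dist, mask} = Scholar.Neighbors.RNNClassifier.radius_neighbors(model, x)

# Each row of `mask` is 1 for training points within :radius of that query, so the
# row sum is the (unweighted) neighbor count behind the `core_samples` test above.
n_neighbors = Nx.sum(mask, axes: [1])
core_samples = Nx.greater_equal(n_neighbors, 2)
```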
@@ -1,16 +1,16 @@
defmodule Scholar.Neighbors.RadiusNearestNeighbors do
defmodule Scholar.Neighbors.RNNClassifier do
@moduledoc """
The Radius Nearest Neighbors.
It implements both classification and regression.
It implements classification.
"""
import Nx.Defn
import Scholar.Shared
require Nx

@derive {Nx.Container,
keep: [:weights, :num_classes, :task, :metric, :radius], containers: [:data, :labels]}
defstruct [:data, :labels, :weights, :num_classes, :task, :metric, :radius]
keep: [:weights, :num_classes, :metric, :radius], containers: [:data, :labels]}
defstruct [:data, :labels, :weights, :num_classes, :metric, :radius]

opts = [
radius: [
@@ -20,6 +20,7 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
],
num_classes: [
type: :pos_integer,
required: true,
doc: "Number of classes in provided labels"
],
weights: [
@@ -47,18 +48,6 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
* Anonymous function of arity 2 that takes two rank-2 tensors.
"""
],
task: [
type: {:in, [:classification, :regression]},
default: :classification,
doc: """
Task that will be performed using Radius Nearest Neighbors. Possible values:
* `:classification` - Classifier implementing the Radius Nearest Neighbors vote.
* `:regression` - Regression based on Radius Nearest Neighbors.
The target is predicted by local interpolation of the targets associated of the nearest neighbors in the training set.
"""
]
]

@@ -70,8 +59,6 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
For classification, provided labels need to be consecutive non-negative integers. If your labels do
not meet this condition, please use `Scholar.Preprocessing.ordinal_encode`
Currently 2D labels are only supported for regression tasks.
## Options
#{NimbleOptions.docs(@opts_schema)}
@@ -88,10 +75,6 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
* `:num_classes` - Number of classes in provided labels.
* `:task` - Task that will be performed using Radius Nearest Neighbors.
For `:classification` task, model will be a classifier implementing the Radius Nearest Neighbors vote.
For `:regression` task, model is a regressor based on Radius Nearest Neighbors.
* `:metric` - The metric function used.
* `:radius` - Radius of neighborhood.
@@ -100,22 +83,24 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> Scholar.Neighbors.RadiusNearestNeighbors.fit(x, y, num_classes: 2)
%Scholar.Neighbors.RadiusNearestNeighbors{
data: Nx.tensor(
iex> Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
%Scholar.Neighbors.RNNClassifier{
data: #Nx.Tensor<
s64[4][2]
[
[1, 2],
[2, 4],
[1, 3],
[2, 5]
]
),
labels: Nx.tensor(
>,
labels: #Nx.Tensor<
s64[4]
[1, 0, 1, 1]
),
>,
weights: :uniform,
num_classes: 2,
task: :classification,
metric: &Scholar.Metrics.Distance.pairwise_minkowski/2,
radius: 1.0
}
@@ -144,17 +129,11 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do

opts = NimbleOptions.validate!(opts, @opts_schema)

if opts[:num_classes] == nil and opts[:task] == :classification do
raise ArgumentError,
"expected :num_classes to be provided for task :classification"
end

%__MODULE__{
data: x,
labels: y,
weights: opts[:weights],
num_classes: opts[:num_classes],
task: opts[:task],
metric: opts[:metric],
radius: opts[:radius]
}
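
With `:num_classes` marked `required: true` in the options schema, the hand-rolled check that used to guard the `:classification` task is no longer needed: `NimbleOptions.validate!/2` rejects a missing `:num_classes` before the struct is built. A small sketch of the new failure mode (error message paraphrased, not taken from this diff):

```elixir
x = Nx.tensor([[1, 2], [2, 4]])
y = Nx.tensor([0, 1])

# Raises NimbleOptions.ValidationError because :num_classes is now a required option.
Scholar.Neighbors.RNNClassifier.fit(x, y)
```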
@@ -171,54 +150,17 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> model = Scholar.Neighbors.RadiusNearestNeighbors.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RadiusNearestNeighbors.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
Nx.tensor(
iex> model = Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RNNClassifier.predict(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
#Nx.Tensor<
s64[2]
[0, 1]
)
>
"""
defn predict(%__MODULE__{labels: labels, weights: weights, task: task} = model, x) do
case task do
:classification ->
{probabilities, outliers_mask} = predict_probability(model, x)
results = Nx.argmax(probabilities, axis: 1)
Nx.select(outliers_mask, -1, results)

:regression ->
{distances, indices} = radius_neighbors(model, x)

x_num_samples = Nx.axis_size(x, 0)
train_num_samples = Nx.axis_size(labels, 0)
labels_rank = Nx.rank(labels)

labels =
if labels_rank == 1 do
Nx.new_axis(labels, 0) |> Nx.broadcast({x_num_samples, train_num_samples})
else
out_size = Nx.axis_size(labels, 1)
Nx.new_axis(labels, 0) |> Nx.broadcast({x_num_samples, train_num_samples, out_size})
end

indices =
if labels_rank == 2,
do: Nx.new_axis(indices, -1) |> Nx.broadcast(labels),
else: indices

case weights do
:distance ->
weights = check_weights(distances)

weights =
if labels_rank == 2,
do: Nx.new_axis(weights, -1) |> Nx.broadcast(labels),
else: weights

Nx.weighted_mean(labels, indices * weights, axes: [1])

:uniform ->
Nx.weighted_mean(labels, indices, axes: [1])
end
end
defn predict(model, x) do
{probabilities, outliers_mask} = predict_probability(model, x)
results = Nx.argmax(probabilities, axis: 1)
Nx.select(outliers_mask, -1, results)
end
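
`predict/2` is now classification-only: it takes the argmax of the per-class probabilities and uses the outlier mask to assign `-1` to queries with no training point inside `:radius`. A hedged example of that outlier path; the far-away query and the trailing expected output are illustrative, not a doctest:

```elixir
x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
y = Nx.tensor([1, 0, 1, 1])
model = Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)  # default radius: 1.0

# [10.0, 10.0] has no training point within radius 1.0, so its row of the neighbor
# mask is all zeros and Nx.select(outliers_mask, -1, results) returns -1 for it.
Scholar.Neighbors.RNNClassifier.predict(model, Nx.tensor([[1.1, 2.0], [10.0, 10.0]]))
# expected: [1, -1]
```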

@doc """
@@ -234,61 +176,29 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> model = Scholar.Neighbors.RadiusNearestNeighbors.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RadiusNearestNeighbors.predict_probability(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{Nx.tensor(
iex> model = Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RNNClassifier.predict_probability(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{#Nx.Tensor<
f32[2][2]
[
[0.5, 0.5],
[0.0, 1.0]
]
),
Nx.tensor(
[0, 0], type: :u8
)}
"""
deftransform predict_probability(%__MODULE__{task: :classification} = model, x) do
predict_proba_n(model, x)
end

@doc """
Find the Radius neighbors of a point.
## Return Values
Returns indices of the selected neighbor points as a mask (1 if a point is a neighbor, 0 otherwise) and their respective distances.
## Examples
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> model = Scholar.Neighbors.RadiusNearestNeighbors.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RadiusNearestNeighbors.radius_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{Nx.tensor(
[
[2.469818353652954, 0.3162313997745514, 1.5811394453048706, 0.7071067690849304],
[0.10000114142894745, 2.1931710243225098, 1.0049877166748047, 3.132091760635376]
]
),
Nx.tensor(
[
[0, 1, 0, 1],
[1, 0, 0, 0]
], type: :u8
)}
>,
#Nx.Tensor<
u8[2]
[0, 0]
>}
"""
defn radius_neighbors(%__MODULE__{metric: metric, radius: radius, data: data}, x) do
distances = metric.(x, data)
{distances, distances <= radius}
end

defnp predict_proba_n(
%__MODULE__{
labels: labels,
weights: weights,
num_classes: num_classes
} = model,
x
) do
defn predict_probability(
%__MODULE__{
labels: labels,
weights: weights,
num_classes: num_classes
} = model,
x
) do
{distances, indices} = radius_neighbors(model, x)
num_samples = Nx.axis_size(x, 0)
outliers_mask = Nx.sum(indices, axes: [1]) == 0
@@ -315,6 +225,40 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
{final_probabilities / Nx.new_axis(normalizer, -1), outliers_mask}
end
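
`predict_probability/2` keeps both weighting modes from the original module: `:uniform` counts one vote per in-radius neighbor, while `:distance` weights votes by proximity (inverse distance, as far as the truncated `check_weights/1` helper below suggests, with exact-zero distances special-cased). A brief sketch of the `:distance` option; the exact probabilities are deliberately not asserted:

```elixir
x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
y = Nx.tensor([1, 0, 1, 1])

# Closer neighbors inside the radius contribute more to their class than farther ones.
model =
  Scholar.Neighbors.RNNClassifier.fit(x, y,
    num_classes: 2,
    weights: :distance,
    radius: 2.0
  )

{probabilities, outliers_mask} =
  Scholar.Neighbors.RNNClassifier.predict_probability(model, Nx.tensor([[1.9, 4.3]]))
```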

@doc """
Find the Radius neighbors of a point.
## Return Values
Returns indices of the selected neighbor points as a mask (1 if a point is a neighbor, 0 otherwise) and their respective distances.
## Examples
iex> x = Nx.tensor([[1, 2], [2, 4], [1, 3], [2, 5]])
iex> y = Nx.tensor([1, 0, 1, 1])
iex> model = Scholar.Neighbors.RNNClassifier.fit(x, y, num_classes: 2)
iex> Scholar.Neighbors.RNNClassifier.radius_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{#Nx.Tensor<
f32[2][4]
[
[2.469818353652954, 0.3162313997745514, 1.5811394453048706, 0.7071067690849304],
[0.10000114142894745, 2.1931710243225098, 1.0049877166748047, 3.132091760635376]
]
>,
#Nx.Tensor<
u8[2][4]
[
[0, 1, 0, 1],
[1, 0, 0, 0]
]
>}
"""
defn radius_neighbors(%__MODULE__{metric: metric, radius: radius, data: data}, x) do
distances = metric.(x, data)
{distances, distances <= radius}
end

defnp check_weights(weights) do
zero_mask = weights == 0
zero_rows = zero_mask |> Nx.any(axes: [1], keep_axes: true) |> Nx.broadcast(weights)