Skip to content

Commit

Permalink
Unify neighbors metrics
Browse files Browse the repository at this point in the history
Closes #265.
  • Loading branch information
josevalim committed May 14, 2024
1 parent ffaac87 commit f358b24
Show file tree
Hide file tree
Showing 10 changed files with 118 additions and 134 deletions.
10 changes: 6 additions & 4 deletions lib/scholar/cluster/dbscan.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@ defmodule Scholar.Cluster.DBSCAN do
type: :integer
],
metric: [
type: {:custom, Scholar.Options, :metric, []},
default: {:minkowski, 2},
type: {:custom, Scholar.Neighbors.Utils, :pairwise_metric, []},
default: &Scholar.Metrics.Distance.pairwise_minkowski/2,
doc: ~S"""
Name of the metric. Possible values:
The function that measures the pairwise distance between two points. Possible values:
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
* `:cosine` - Cosine metric.
* Anonymous function of arity 2 that takes two rank-2 tensors.
"""
],
weights: [
Expand Down
68 changes: 49 additions & 19 deletions lib/scholar/metrics/distance.ex
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ defmodule Scholar.Metrics.Distance do
]
]

minkowski_schema =
general_schema ++
[
p: [
type: {:or, [{:custom, Scholar.Options, :positive_number, []}, {:in, [:infinity]}]},
default: 2.0,
doc: """
A positive parameter of Minkowski distance or :infinity (then Chebyshev metric computed).
"""
]
pairwise_minkowski_schema =
[
p: [
type: {:or, [{:custom, Scholar.Options, :positive_number, []}, {:in, [:infinity]}]},
default: 2.0,
doc: """
A positive parameter of Minkowski distance or :infinity (then Chebyshev metric computed).
"""
]
]

@general_schema NimbleOptions.new!(general_schema)
@minkowski_schema NimbleOptions.new!(minkowski_schema)
@minkowski_schema NimbleOptions.new!(general_schema ++ pairwise_minkowski_schema)
@pairwise_minkowski_schema NimbleOptions.new!(pairwise_minkowski_schema)

@doc """
Standard euclidean distance ($L_{2}$ distance).
Expand Down Expand Up @@ -361,6 +361,40 @@ defmodule Scholar.Metrics.Distance do
end
end

@doc """
Computes the pairwise Minkowski distance.
## Examples
iex> x = Nx.iota({2, 3})
iex> y = Nx.reverse(x)
iex> Scholar.Metrics.Distance.pairwise_minkowski(x, y)
#Nx.Tensor<
f32[2][2]
[
[5.916079998016357, 2.8284270763397217],
[2.8284270763397217, 5.916079998016357]
]
>
"""
deftransform pairwise_minkowski(x, y, opts \\ []) do
pairwise_minkowski_n(x, y, NimbleOptions.validate!(opts, @pairwise_minkowski_schema))
end

defnp pairwise_minkowski_n(x, y, opts) do
p = opts[:p]

cond do
p == 2 ->
pairwise_euclidean(x, y)

true ->
x = Nx.new_axis(x, 1)
y = Nx.new_axis(y, 0)
minkowski_n(x, y, axes: [-1], p: p)
end
end

@doc """
Cosine distance.
Expand Down Expand Up @@ -607,18 +641,14 @@ defmodule Scholar.Metrics.Distance do
## Examples
iex> x = Nx.iota({6, 6})
iex> x = Nx.iota({2, 3})
iex> y = Nx.reverse(x)
iex> Scholar.Metrics.Distance.pairwise_euclidean(x, y)
#Nx.Tensor<
f32[6][6]
f32[2][2]
[
[73.9594497680664, 59.380130767822266, 44.87761306762695, 30.561412811279297, 16.911535263061523, 8.366600036621094],
[59.380130767822266, 44.87761306762695, 30.561412811279297, 16.911535263061523, 8.366600036621094, 16.911535263061523],
[44.87761306762695, 30.561412811279297, 16.911535263061523, 8.366600036621094, 16.911535263061523, 30.561412811279297],
[30.561412811279297, 16.911535263061523, 8.366600036621094, 16.911535263061523, 30.561412811279297, 44.87761306762695],
[16.911535263061523, 8.366600036621094, 16.911535263061523, 30.561412811279297, 44.87761306762695, 59.380130767822266],
[8.366600036621094, 16.911535263061523, 30.561412811279297, 44.87761306762695, 59.380130767822266, 73.9594497680664]
[5.916079998016357, 2.8284270763397217],
[2.8284270763397217, 5.916079998016357]
]
>
"""
Expand Down
28 changes: 6 additions & 22 deletions lib/scholar/neighbors/brute_knn.ex
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,17 @@ defmodule Scholar.Neighbors.BruteKNN do
doc: "The number of nearest neighbors."
],
metric: [
type: {:or, [{:custom, Scholar.Options, :metric, []}, {:fun, 2}]},
default: {:minkowski, 2},
type: {:custom, Scholar.Neighbors.Utils, :pairwise_metric, []},
default: &Scholar.Metrics.Distance.pairwise_minkowski/2,
doc: ~S"""
The function that measures the distance between two points. Possible values:
The function that measures the pairwise distance between two points. Possible values:
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
* `:cosine` - Cosine metric.
* Anonymous function of arity 2 that takes two rank-1 tensors of same dimension and returns a scalar.
* Anonymous function of arity 2 that takes two rank-2 tensors.
"""
],
batch_size: [
Expand Down Expand Up @@ -86,21 +86,9 @@ defmodule Scholar.Neighbors.BruteKNN do
"""
end

metric =
case opts[:metric] do
{:minkowski, p} ->
&Scholar.Metrics.Distance.minkowski(&1, &2, p: p)

:cosine ->
&Scholar.Metrics.Distance.cosine/2

fun when is_function(fun, 2) ->
fun
end

%__MODULE__{
num_neighbors: k,
metric: metric,
metric: opts[:metric],
data: data,
batch_size: opts[:batch_size]
}
Expand Down Expand Up @@ -245,11 +233,7 @@ defmodule Scholar.Neighbors.BruteKNN do
defnp brute_force_search(data, query, opts) do
k = opts[:num_neighbors]
metric = opts[:metric]
{m, d} = Nx.shape(data)
n = Nx.axis_size(query, 0)
x = query |> Nx.new_axis(1) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
y = data |> Nx.new_axis(0) |> Nx.broadcast({n, m, d}) |> Nx.vectorize([:query, :data])
distances = metric.(x, y) |> Nx.devectorize() |> Nx.rename(nil)
distances = metric.(query, data)

neighbor_indices =
Nx.argsort(distances, axis: 1, type: :u64) |> Nx.slice_along_axis(0, k, axis: 1)
Expand Down
39 changes: 11 additions & 28 deletions lib/scholar/neighbors/k_nearest_neighbors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,17 @@ defmodule Scholar.Neighbors.KNearestNeighbors do
"""
],
metric: [
type: {:custom, Scholar.Options, :metric, []},
default: {:minkowski, 2},
type: {:custom, Scholar.Neighbors.Utils, :pairwise_metric, []},
default: &Scholar.Metrics.Distance.pairwise_minkowski/2,
doc: ~S"""
Name of the metric. Possible values:
The function that measures the pairwise distance between two points. Possible values:
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
* `:cosine` - Cosine metric.
* Anonymous function of arity 2 that takes two rank-2 tensors.
"""
],
task: [
Expand Down Expand Up @@ -119,7 +121,7 @@ defmodule Scholar.Neighbors.KNearestNeighbors do
weights: :uniform,
num_classes: 2,
task: :classification,
metric: {:minkowski, 2}
metric: &Scholar.Metrics.Distance.pairwise_minkowski/2
}
"""
deftransform fit(x, y, opts \\ []) do
Expand Down Expand Up @@ -287,8 +289,8 @@ defmodule Scholar.Neighbors.KNearestNeighbors do
iex> Scholar.Neighbors.KNearestNeighbors.k_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{Nx.tensor(
[
[0.3162279427051544, 0.7071065902709961, 1.5811389684677124, 2.469817876815796],
[0.10000002384185791, 1.0049875974655151, 2.193171262741089, 3.132091760635376]
[0.3162313997745514, 0.7071067690849304, 1.5811394453048706, 2.469818353652954],
[0.10000114142894745, 1.0049877166748047, 2.1931710243225098, 3.132091760635376]
]
),
Nx.tensor(
Expand All @@ -306,27 +308,8 @@ defmodule Scholar.Neighbors.KNearestNeighbors do
} = _model,
x
) do
{num_samples, num_features} = Nx.shape(data)
{num_samples_x, _num_features} = Nx.shape(x)
broadcast_shape = {num_samples_x, num_samples, num_features}
data_broadcast = Nx.new_axis(data, 0) |> Nx.broadcast(broadcast_shape)
x_broadcast = Nx.new_axis(x, 1) |> Nx.broadcast(broadcast_shape)

dist =
case metric do
{:minkowski, p} ->
Scholar.Metrics.Distance.minkowski(
data_broadcast,
x_broadcast,
axes: [-1],
p: p
)

:cosine ->
Scholar.Metrics.Distance.pairwise_cosine(x, data)
end

{val, ind} = Nx.top_k(-dist, k: default_num_neighbors)
distances = metric.(x, data)
{val, ind} = Nx.top_k(-distances, k: default_num_neighbors)
{-val, ind}
end

Expand Down
21 changes: 7 additions & 14 deletions lib/scholar/neighbors/kd_tree.ex
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,17 @@ defmodule Scholar.Neighbors.KDTree do
doc: "The number of neighbors to use by default for `k_neighbors` queries"
],
metric: [
type: {:custom, Scholar.Options, :metric, []},
default: {:minkowski, 2},
type: {:custom, Scholar.Neighbors.Utils, :metric, []},
default: &Scholar.Metrics.Distance.minkowski/2,
doc: ~S"""
Name of the metric. Possible values:
The function that measures the distance between two points. Possible values:
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
* `:cosine` - Cosine metric.
* Anonymous function of arity 2 that takes two rank-1 tensors and returns a scalar.
"""
]
]
Expand Down Expand Up @@ -70,21 +72,12 @@ defmodule Scholar.Neighbors.KDTree do
deftransform fit(tensor, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)

metric =
case opts[:metric] do
{:minkowski, p} ->
&Scholar.Metrics.Distance.minkowski(&1, &2, p: p)

:cosine ->
&Scholar.Metrics.Distance.pairwise_cosine/2
end

%__MODULE__{
levels: levels(tensor),
indices: fit_n(tensor),
data: tensor,
num_neighbors: opts[:num_neighbors],
metric: metric
metric: opts[:metric]
}
end

Expand Down
48 changes: 13 additions & 35 deletions lib/scholar/neighbors/radius_nearest_neighbors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
"""
],
metric: [
type: {:custom, Scholar.Options, :metric, []},
default: {:minkowski, 2},
type: {:custom, Scholar.Neighbors.Utils, :pairwise_metric, []},
default: &Scholar.Metrics.Distance.pairwise_minkowski/2,
doc: ~S"""
Name of the metric. Possible values:
The function that measures the pairwise distance between two points. Possible values:
* `{:minkowski, p}` - Minkowski metric. By changing value of `p` parameter (a positive number or `:infinity`)
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
we can set Manhattan (`1`), Euclidean (`2`), Chebyshev (`:infinity`), or any arbitrary $L_p$ metric.
* `:cosine` - Cosine metric.
* Anonymous function of arity 2 that takes two rank-2 tensors.
"""
],
task: [
Expand Down Expand Up @@ -90,7 +92,7 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
For `:classification` task, model will be a classifier implementing the Radius Nearest Neighbors vote.
For `:regression` task, model is a regressor based on Radius Nearest Neighbors.
* `:metric` - Name of the metric.
* `:metric` - The metric function used.
* `:radius` - Radius of neighborhood.
Expand All @@ -114,7 +116,7 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
weights: :uniform,
num_classes: 2,
task: :classification,
metric: {:minkowski, 2},
metric: &Scholar.Metrics.Distance.pairwise_minkowski/2,
radius: 1.0
}
"""
Expand Down Expand Up @@ -244,12 +246,7 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
[0, 0], type: :u8
)}
"""
deftransform predict_proba(
%__MODULE__{
task: :classification
} = model,
x
) do
deftransform predict_proba(%__MODULE__{task: :classification} = model, x) do
predict_proba_n(model, x)
end

Expand All @@ -268,8 +265,8 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
iex> Scholar.Neighbors.RadiusNearestNeighbors.radius_neighbors(model, Nx.tensor([[1.9, 4.3], [1.1, 2.0]]))
{Nx.tensor(
[
[2.469817876815796, 0.3162279427051544, 1.5811389684677124, 0.7071065902709961],
[0.10000002384185791, 2.193171262741089, 1.0049875974655151, 3.132091760635376]
[2.469818353652954, 0.3162313997745514, 1.5811394453048706, 0.7071067690849304],
[0.10000114142894745, 2.1931710243225098, 1.0049877166748047, 3.132091760635376]
]
),
Nx.tensor(
Expand All @@ -280,27 +277,8 @@ defmodule Scholar.Neighbors.RadiusNearestNeighbors do
)}
"""
defn radius_neighbors(%__MODULE__{metric: metric, radius: radius, data: data}, x) do
{num_samples, num_features} = Nx.shape(data)
{num_samples_x, _num_features} = Nx.shape(x)
broadcast_shape = {num_samples_x, num_samples, num_features}
data_broadcast = Nx.new_axis(data, 0) |> Nx.broadcast(broadcast_shape)
x_broadcast = Nx.new_axis(x, 1) |> Nx.broadcast(broadcast_shape)

dist =
case metric do
{:minkowski, p} ->
Scholar.Metrics.Distance.minkowski(
data_broadcast,
x_broadcast,
axes: [-1],
p: p
)

:cosine ->
Scholar.Metrics.Distance.pairwise_cosine(x, data)
end

{dist, dist <= radius}
distances = metric.(x, data)
{distances, distances <= radius}
end

defnp predict_proba_n(
Expand Down
Loading

0 comments on commit f358b24

Please sign in to comment.