Skip to content

Commit

Permalink
Pairwise distances (elixir-nx#177)
Browse files Browse the repository at this point in the history
* Update mix.installs

* Add pairwise disctances

* Use pairwise distance

* Format
  • Loading branch information
msluszniak authored Oct 3, 2023
1 parent 8371df9 commit 07405d6
Show file tree
Hide file tree
Showing 8 changed files with 264 additions and 152 deletions.
8 changes: 5 additions & 3 deletions lib/scholar/cluster/affinity_propagation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,10 @@ defmodule Scholar.Cluster.AffinityPropagation do
Nx.new_axis(cluster_centers, 0) |> Nx.broadcast(broadcast_shape),
axes: [-1]
)

dist = Scholar.Metrics.Distance.pairwise_euclidean(x, cluster_centers)

Nx.select(Nx.is_nan(dist), Nx.Constants.infinity(Nx.type(dist)), dist)
|> Nx.argmin(axis: 1)
end

Expand All @@ -311,9 +315,7 @@ defmodule Scholar.Cluster.AffinityPropagation do
n = Nx.axis_size(data, 0)
self_preference = opts[:self_preference]

norm1 = Nx.sum(data ** 2, axes: [1], keep_axes: true)
norm2 = Nx.transpose(norm1)
dist = -1 * (norm1 + norm2 - 2 * Nx.dot(data, [1], data, [1]))
dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data)

fill_in =
cond do
Expand Down
28 changes: 2 additions & 26 deletions lib/scholar/cluster/k_means.ex
Original file line number Diff line number Diff line change
Expand Up @@ -300,24 +300,8 @@ defmodule Scholar.Cluster.KMeans do
"""
defn predict(%__MODULE__{clusters: clusters} = _model, x) do
assert_same_shape!(x[0], clusters[0])
{num_clusters, _} = Nx.shape(clusters)
{num_samples, num_features} = Nx.shape(x)

clusters =
clusters
|> Nx.new_axis(1)
|> Nx.broadcast({num_clusters, num_samples, num_features})
|> Nx.reshape({num_clusters * num_samples, num_features})

inertia_for_centroids =
Scholar.Metrics.Distance.squared_euclidean(
Nx.tile(x, [num_clusters, 1]),
clusters,
axes: [1]
)
|> Nx.reshape({num_clusters, num_samples})

inertia_for_centroids |> Nx.argmin(axis: 0)
Scholar.Metrics.Distance.pairwise_squared_euclidean(clusters, x) |> Nx.argmin(axis: 0)
end

@doc """
Expand All @@ -343,14 +327,6 @@ defmodule Scholar.Cluster.KMeans do
)
"""
defn transform(%__MODULE__{clusters: clusters} = _model, x) do
{num_clusters, num_features} = Nx.shape(clusters)
{num_samples, _} = Nx.shape(x)
broadcast_shape = {num_samples, num_clusters, num_features}

Scholar.Metrics.Distance.euclidean(
Nx.new_axis(x, 1) |> Nx.broadcast(broadcast_shape),
Nx.new_axis(clusters, 0) |> Nx.broadcast(broadcast_shape),
axes: [-1]
)
Scholar.Metrics.Distance.pairwise_euclidean(x, clusters)
end
end
40 changes: 22 additions & 18 deletions lib/scholar/manifold/tsne.ex
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,9 @@ defmodule Scholar.Manifold.TSNE do
end

defnp fit_n(x, key, opts \\ []) do
{perplexity, learning_rate, num_iters, num_components, exaggeration, init, metric} =
{perplexity, learning_rate, num_iters, num_components, exaggeration, init} =
{opts[:perplexity], opts[:learning_rate], opts[:num_iters], opts[:num_components],
opts[:exaggeration], opts[:init], opts[:metric]}
opts[:exaggeration], opts[:init]}

x = to_float(x)
{n, _dims} = Nx.shape(x)
Expand All @@ -138,22 +138,22 @@ defmodule Scholar.Manifold.TSNE do
x_embedded / Nx.standard_deviation(x_embedded[[.., 0]]) * 1.0e-4
end

p = p_joint(x, perplexity, metric)
p = p_joint(x, perplexity, opts)

{y, _} =
while {y1, {y2 = y1, learning_rate, p}},
i <- 2..(num_iters - 1),
unroll: opts[:learning_loop_unroll] do
q = q_joint(y1, metric)
grad = gradient(p * exaggeration(i, exaggeration), q, y1, metric)
q = q_joint(y1, opts)
grad = gradient(p * exaggeration(i, exaggeration), q, y1, opts)
y_next = y1 - learning_rate * grad + momentum(i) * (y1 - y2)
{y_next, {y1, learning_rate, p}}
end

y
end

defnp pairwise_dist(x, metric) do
defn pairwise_dist(x, opts) do
{num_samples, num_features} = Nx.shape(x)
broadcast_shape = {num_samples, num_samples, num_features}

Expand All @@ -167,18 +167,18 @@ defmodule Scholar.Manifold.TSNE do
|> Nx.reshape({num_samples, 1, num_features})
|> Nx.broadcast(broadcast_shape)

case metric do
case opts[:metric] do
:squared_euclidean ->
Distance.squared_euclidean(t1, t2, axes: [2])
Distance.pairwise_squared_euclidean(x)

:euclidean ->
Distance.euclidean(t1, t2, axes: [2])
Distance.pairwise_euclidean(x)

:manhattan ->
Distance.manhattan(t1, t2, axes: [2])

:cosine ->
Distance.cosine(t1, t2, axes: [2])
Distance.pairwise_cosine(x)

:chebyshev ->
Distance.chebyshev(t1, t2, axes: [2])
Expand Down Expand Up @@ -239,8 +239,12 @@ defmodule Scholar.Manifold.TSNE do
low = low,
high = high,
{max_iters, tol,
perplexity_val = Nx.Constants.infinity(to_float_type(target_perplexity)),
distances, target_perplexity, i = 0}
perplexity_val =
Nx.Constants.infinity(
Nx.Type.to_floating(
Nx.Type.merge(Nx.type(target_perplexity), Nx.type(distances))
)
), distances, target_perplexity, i = 0}
},
i < max_iters and Nx.abs(perplexity_val - target_perplexity) > tol do
mid = (low + high) / 2
Expand All @@ -261,18 +265,18 @@ defmodule Scholar.Manifold.TSNE do
(high + low) / 2
end

defnp q_joint(y, metric) do
distances = pairwise_dist(y, metric)
defnp q_joint(y, opts) do
distances = pairwise_dist(y, opts)
n = Nx.axis_size(distances, 0)
inv_distances = 1 / (1 + distances)
inv_distances = inv_distances / Nx.sum(inv_distances)
Nx.put_diagonal(inv_distances, Nx.broadcast(0, {n}))
end

defnp gradient(p, q, y, metric) do
defnp gradient(p, q, y, opts) do
pq_diff = Nx.new_axis(p - q, 2)
y_diff = Nx.new_axis(y, 1) - Nx.new_axis(y, 0)
distances = pairwise_dist(y, metric)
distances = pairwise_dist(y, opts)

inv_distances = Nx.new_axis(1 / (1 + distances), 2)

Expand All @@ -297,9 +301,9 @@ defmodule Scholar.Manifold.TSNE do
end
end

defnp p_joint(x, perplexity, metric) do
defnp p_joint(x, perplexity, opts) do
{n, _} = Nx.shape(x)
distances = pairwise_dist(x, metric)
distances = pairwise_dist(x, opts)
sigmas = find_sigmas(distances, perplexity)
p_cond = p_conditional(distances, sigmas)
(p_cond + Nx.transpose(p_cond)) / (2 * n)
Expand Down
20 changes: 4 additions & 16 deletions lib/scholar/metrics/clustering.ex
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ defmodule Scholar.Metrics.Clustering do
iex> Scholar.Metrics.Clustering.silhouette_samples(x, labels, num_clusters: 3)
#Nx.Tensor<
f32[5]
[0.0, -0.9782054424285889, 0.0, -0.18546819686889648, -0.5929657816886902]
[0.0, -0.9782054424285889, 0.0, -0.18546827137470245, -0.5929659008979797]
>
"""
deftransform silhouette_samples(x, labels, opts \\ []) do
Expand Down Expand Up @@ -81,7 +81,7 @@ defmodule Scholar.Metrics.Clustering do
iex> Scholar.Metrics.Clustering.silhouette_score(x, labels, num_clusters: 3)
#Nx.Tensor<
f32
-0.35132789611816406
-0.35132792592048645
>
"""
deftransform silhouette_score(x, labels, opts \\ []) do
Expand All @@ -94,21 +94,9 @@ defmodule Scholar.Metrics.Clustering do

defnp inner_and_outer_dist(x, labels, opts) do
num_clusters = opts[:num_clusters]
{num_samples, num_features} = Nx.shape(x)
num_samples = Nx.axis_size(x, 0)
inf = Nx.Constants.infinity(to_float_type(x))
broadcast_shape = {num_samples, num_samples, num_features}

x_a =
x
|> Nx.new_axis(0)
|> Nx.broadcast(broadcast_shape)

x_b =
x
|> Nx.new_axis(1)
|> Nx.broadcast(broadcast_shape)

pairwise_dist = Scholar.Metrics.Distance.euclidean(x_a, x_b, axes: [2])
pairwise_dist = Scholar.Metrics.Distance.pairwise_euclidean(x)
membership_mask = Nx.reshape(labels, {num_samples, 1}) == Nx.iota({1, num_clusters})
cluster_size = membership_mask |> Nx.sum(axes: [0]) |> Nx.reshape({1, num_clusters})
dist_in_cluster = Nx.dot(pairwise_dist, membership_mask)
Expand Down
152 changes: 151 additions & 1 deletion lib/scholar/metrics/distance.ex
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,7 @@ defmodule Scholar.Metrics.Distance do
merged_type = Nx.Type.merge(Nx.type(x), Nx.type(y))
res = Nx.select(one_zero?, Nx.tensor(0, type: merged_type), res)
one_merged_type = Nx.tensor(1, type: merged_type)
one_merged_type - Nx.select(both_zero?, one_merged_type, res)
Nx.max(0, one_merged_type - Nx.select(both_zero?, one_merged_type, res))
end

@doc """
Expand Down Expand Up @@ -549,4 +549,154 @@ defmodule Scholar.Metrics.Distance do
w = Nx.as_type(w, result_type)
Nx.weighted_mean(x != y, w, axes: opts[:axes]) |> Nx.as_type(result_type)
end

@doc """
Pairwise squared euclidean distance.
## Examples
iex> x = Nx.iota({6, 6})
iex> y = Nx.reverse(x)
iex> Scholar.Metrics.Distance.pairwise_squared_euclidean(x, y)
#Nx.Tensor<
s64[6][6]
[
[5470, 3526, 2014, 934, 286, 70],
[3526, 2014, 934, 286, 70, 286],
[2014, 934, 286, 70, 286, 934],
[934, 286, 70, 286, 934, 2014],
[286, 70, 286, 934, 2014, 3526],
[70, 286, 934, 2014, 3526, 5470]
]
>
"""
defn pairwise_squared_euclidean(x, y) do
y_norm = Nx.sum(y * y, axes: [1]) |> Nx.new_axis(0)
x_norm = Nx.sum(x * x, axes: [1], keep_axes: true)
Nx.max(0, x_norm + y_norm - 2 * Nx.dot(x, [-1], y, [-1]))
end

@doc """
Pairwise squared euclidean distance. It is equivalent to
Scholar.Metrics.Distance.pairwise_squared_euclidean(x, x)
## Examples
iex> x = Nx.iota({6, 6})
iex> Scholar.Metrics.Distance.pairwise_squared_euclidean(x)
#Nx.Tensor<
s64[6][6]
[
[0, 216, 864, 1944, 3456, 5400],
[216, 0, 216, 864, 1944, 3456],
[864, 216, 0, 216, 864, 1944],
[1944, 864, 216, 0, 216, 864],
[3456, 1944, 864, 216, 0, 216],
[5400, 3456, 1944, 864, 216, 0]
]
>
"""
defn pairwise_squared_euclidean(x) do
x_norm = Nx.sum(x * x, axes: [1], keep_axes: true)
Nx.max(0, x_norm + Nx.transpose(x_norm) - 2 * Nx.dot(x, [-1], x, [-1]))
end

@doc """
Pairwise euclidean distance.
## Examples
iex> x = Nx.iota({6, 6})
iex> y = Nx.reverse(x)
iex> Scholar.Metrics.Distance.pairwise_euclidean(x, y)
#Nx.Tensor<
f32[6][6]
[
[73.9594497680664, 59.380130767822266, 44.87761306762695, 30.561412811279297, 16.911535263061523, 8.366600036621094],
[59.380130767822266, 44.87761306762695, 30.561412811279297, 16.911535263061523, 8.366600036621094, 16.911535263061523],
[44.87761306762695, 30.561412811279297, 16.911535263061523, 8.366600036621094, 16.911535263061523, 30.561412811279297],
[30.561412811279297, 16.911535263061523, 8.366600036621094, 16.911535263061523, 30.561412811279297, 44.87761306762695],
[16.911535263061523, 8.366600036621094, 16.911535263061523, 30.561412811279297, 44.87761306762695, 59.380130767822266],
[8.366600036621094, 16.911535263061523, 30.561412811279297, 44.87761306762695, 59.380130767822266, 73.9594497680664]
]
>
"""
defn pairwise_euclidean(x, y) do
Nx.sqrt(pairwise_squared_euclidean(x, y))
end

@doc """
Pairwise euclidean distance. It is equivalent to
Scholar.Metrics.Distance.pairwise_euclidean(x, x)
## Examples
iex> x = Nx.iota({6, 6})
iex> Scholar.Metrics.Distance.pairwise_euclidean(x)
#Nx.Tensor<
f32[6][6]
[
[0.0, 14.696938514709473, 29.393877029418945, 44.090816497802734, 58.78775405883789, 73.48469543457031],
[14.696938514709473, 0.0, 14.696938514709473, 29.393877029418945, 44.090816497802734, 58.78775405883789],
[29.393877029418945, 14.696938514709473, 0.0, 14.696938514709473, 29.393877029418945, 44.090816497802734],
[44.090816497802734, 29.393877029418945, 14.696938514709473, 0.0, 14.696938514709473, 29.393877029418945],
[58.78775405883789, 44.090816497802734, 29.393877029418945, 14.696938514709473, 0.0, 14.696938514709473],
[73.48469543457031, 58.78775405883789, 44.090816497802734, 29.393877029418945, 14.696938514709473, 0.0]
]
>
"""
defn pairwise_euclidean(x) do
Nx.sqrt(pairwise_squared_euclidean(x))
end

@doc """
Pairwise cosine distance.
## Examples
iex> x = Nx.iota({6, 6})
iex> y = Nx.reverse(x)
iex> Scholar.Metrics.Distance.pairwise_cosine(x, y)
#Nx.Tensor<
f32[6][6]
[
[0.2050153613090515, 0.21226388216018677, 0.22395789623260498, 0.24592703580856323, 0.30156970024108887, 0.6363636255264282],
[0.03128105401992798, 0.03429150581359863, 0.039331674575805664, 0.049365341663360596, 0.07760530710220337, 0.30156970024108887],
[0.014371514320373535, 0.01644366979598999, 0.020004630088806152, 0.02736520767211914, 0.049365341663360596, 0.24592703580856323],
[0.0091819167137146, 0.010854601860046387, 0.013785064220428467, 0.020004630088806152, 0.039331674575805664, 0.22395789623260498],
[0.006820023059844971, 0.008272230625152588, 0.010854601860046387, 0.01644366979598999, 0.03429150581359863, 0.21226388216018677],
[0.005507469177246094, 0.006820023059844971, 0.0091819167137146, 0.014371514320373535, 0.03128105401992798, 0.2050153613090515]
]
>
"""
defn pairwise_cosine(x, y) do
x_normalized = Scholar.Preprocessing.normalize(x, axes: [1])
y_normalized = Scholar.Preprocessing.normalize(y, axes: [1])
Nx.max(0, 1 - Nx.dot(x_normalized, [-1], y_normalized, [-1]))
end

@doc """
Pairwise cosine distance. It is equivalent to
Scholar.Metrics.Distance.pairwise_euclidean(x, x)
## Examples
iex> x = Nx.iota({6, 6})
iex> Scholar.Metrics.Distance.pairwise_cosine(x)
#Nx.Tensor<
f32[6][6]
[
[0.0, 0.0793418288230896, 0.1139642596244812, 0.13029760122299194, 0.1397092342376709, 0.14581435918807983],
[0.0793418288230896, 0.0, 0.0032819509506225586, 0.006624102592468262, 0.008954286575317383, 0.01060718297958374],
[0.1139642596244812, 0.0032819509506225586, 1.1920928955078125e-7, 5.82277774810791e-4, 0.0013980269432067871, 0.0020949840545654297],
[0.13029760122299194, 0.006624102592468262, 5.82277774810791e-4, 5.960464477539063e-8, 1.7595291137695312e-4, 4.686713218688965e-4],
[0.1397092342376709, 0.008954286575317383, 0.0013980269432067871, 1.7595291137695312e-4, 0.0, 7.027387619018555e-5],
[0.14581435918807983, 0.01060718297958374, 0.0020949840545654297, 4.686713218688965e-4, 7.027387619018555e-5, 0.0]
]
>
"""
defn pairwise_cosine(x) do
x_normalized = Scholar.Preprocessing.normalize(x, axes: [1])
Nx.max(0, 1 - Nx.dot(x_normalized, [-1], x_normalized, [-1]))
end
end
Loading

0 comments on commit 07405d6

Please sign in to comment.