Skip to content

Commit

Permalink
Simplify AP
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Oct 24, 2023
1 parent 84c8878 commit 4b3e824
Showing 1 changed file with 20 additions and 35 deletions.
55 changes: 20 additions & 35 deletions lib/scholar/cluster/affinity_propagation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ defmodule Scholar.Cluster.AffinityPropagation do
current value is maintained relative to incoming values (weighted 1 - damping).
"""
],
self_preference: [
preference: [
type: :float,
doc: """
Preferences for each point - points with larger values of preferences
Expand Down Expand Up @@ -121,13 +121,11 @@ defmodule Scholar.Cluster.AffinityPropagation do
data = to_float(data)
iterations = opts[:iterations]
damping_factor = opts[:damping_factor]
self_preference = opts[:self_preference]
converge_after = opts[:converge_after]
num_samples = Nx.axis_size(data, 0)

{initial_a, initial_r, s} = initialize_matrices(data, self_preference: self_preference)
n = Nx.axis_size(data, 0)
s = initialize_similarity(data, opts)

{n, _} = Nx.shape(initial_a)
zero_n = Nx.tensor(0, type: Nx.type(s)) |> Nx.broadcast({n, n})
{normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {n, n}, type: Nx.type(s))

s =
Expand All @@ -138,11 +136,11 @@ defmodule Scholar.Cluster.AffinityPropagation do

range = Nx.iota({n})

e = Nx.broadcast(Nx.s64(0), {num_samples, converge_after})
e = Nx.broadcast(Nx.s64(0), {n, converge_after})
stop = Nx.u8(0)

{{a, r, it}, _} =
while {{a = initial_a, r = initial_r, i = 0}, {s, range, stop, e}},
while {{a = zero_n, r = zero_n, i = 0}, {s, range, stop, e}},
i < iterations and not stop do
temp = a + s
indices = Nx.argmax(temp, axis: 1)
Expand Down Expand Up @@ -177,7 +175,7 @@ defmodule Scholar.Cluster.AffinityPropagation do
stop =
if i >= converge_after do
se = Nx.sum(e, axes: [1])
unconverged = Nx.sum((se == 0) + (se == converge_after)) != num_samples
unconverged = Nx.sum((se == 0) + (se == converge_after)) != n

if (not unconverged and k > 0) or i == iterations do
Nx.u8(1)
Expand Down Expand Up @@ -232,6 +230,19 @@ defmodule Scholar.Cluster.AffinityPropagation do
}
end

defnp initialize_similarity(data, opts \\ []) do
n = Nx.axis_size(data, 0)
dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data)

fill_in =
case opts[:preference] do
nil -> Nx.broadcast(Nx.median(dist), {n})
preference -> Nx.broadcast(preference, {n})
end

Nx.put_diagonal(dist, fill_in)
end

@doc """
Optionally prune clusters, indices, and labels to only valid entries.
Expand Down Expand Up @@ -306,30 +317,4 @@ defmodule Scholar.Cluster.AffinityPropagation do
Nx.select(Nx.is_nan(dist), Nx.Constants.infinity(Nx.type(dist)), dist)
|> Nx.argmin(axis: 1)
end

defnp initialize_matrices(data, opts \\ []) do
{n, _} = Nx.shape(data)
self_preference = opts[:self_preference]

similarity_matrix = initialize_similarity(data, self_preference: self_preference)

zero = Nx.tensor(0, type: Nx.type(similarity_matrix))
availability_matrix = Nx.broadcast(zero, {n, n})
responsibility_matrix = Nx.broadcast(zero, {n, n})

{availability_matrix, responsibility_matrix, similarity_matrix}
end

defnp initialize_similarity(data, opts \\ []) do
n = Nx.axis_size(data, 0)
dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data)

fill_in =
case opts[:self_preference] do
nil -> Nx.broadcast(Nx.median(dist), {n})
self_preference -> Nx.broadcast(self_preference, {n})
end

Nx.put_diagonal(dist, fill_in)
end
end

0 comments on commit 4b3e824

Please sign in to comment.