From 4b3e824623e43ade4e4df433a2f2edb2d5bd94fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Tue, 24 Oct 2023 13:22:47 +0200 Subject: [PATCH] Simplify AP --- lib/scholar/cluster/affinity_propagation.ex | 55 ++++++++------------- 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/lib/scholar/cluster/affinity_propagation.ex b/lib/scholar/cluster/affinity_propagation.ex index ac2bb5b5..2e562933 100644 --- a/lib/scholar/cluster/affinity_propagation.ex +++ b/lib/scholar/cluster/affinity_propagation.ex @@ -39,7 +39,7 @@ defmodule Scholar.Cluster.AffinityPropagation do current value is maintained relative to incoming values (weighted 1 - damping). """ ], - self_preference: [ + preference: [ type: :float, doc: """ Preferences for each point - points with larger values of preferences @@ -121,13 +121,11 @@ defmodule Scholar.Cluster.AffinityPropagation do data = to_float(data) iterations = opts[:iterations] damping_factor = opts[:damping_factor] - self_preference = opts[:self_preference] converge_after = opts[:converge_after] - num_samples = Nx.axis_size(data, 0) - - {initial_a, initial_r, s} = initialize_matrices(data, self_preference: self_preference) + n = Nx.axis_size(data, 0) + s = initialize_similarity(data, opts) - {n, _} = Nx.shape(initial_a) + zero_n = Nx.tensor(0, type: Nx.type(s)) |> Nx.broadcast({n, n}) {normal, _new_key} = Nx.Random.normal(key, 0, 1, shape: {n, n}, type: Nx.type(s)) s = @@ -138,11 +136,11 @@ defmodule Scholar.Cluster.AffinityPropagation do range = Nx.iota({n}) - e = Nx.broadcast(Nx.s64(0), {num_samples, converge_after}) + e = Nx.broadcast(Nx.s64(0), {n, converge_after}) stop = Nx.u8(0) {{a, r, it}, _} = - while {{a = initial_a, r = initial_r, i = 0}, {s, range, stop, e}}, + while {{a = zero_n, r = zero_n, i = 0}, {s, range, stop, e}}, i < iterations and not stop do temp = a + s indices = Nx.argmax(temp, axis: 1) @@ -177,7 +175,7 @@ defmodule Scholar.Cluster.AffinityPropagation do stop = if i >= converge_after do se = Nx.sum(e, axes: [1]) - unconverged = Nx.sum((se == 0) + (se == converge_after)) != num_samples + unconverged = Nx.sum((se == 0) + (se == converge_after)) != n if (not unconverged and k > 0) or i == iterations do Nx.u8(1) @@ -232,6 +230,19 @@ defmodule Scholar.Cluster.AffinityPropagation do } end + defnp initialize_similarity(data, opts \\ []) do + n = Nx.axis_size(data, 0) + dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data) + + fill_in = + case opts[:preference] do + nil -> Nx.broadcast(Nx.median(dist), {n}) + preference -> Nx.broadcast(preference, {n}) + end + + Nx.put_diagonal(dist, fill_in) + end + @doc """ Optionally prune clusters, indices, and labels to only valid entries. @@ -306,30 +317,4 @@ defmodule Scholar.Cluster.AffinityPropagation do Nx.select(Nx.is_nan(dist), Nx.Constants.infinity(Nx.type(dist)), dist) |> Nx.argmin(axis: 1) end - - defnp initialize_matrices(data, opts \\ []) do - {n, _} = Nx.shape(data) - self_preference = opts[:self_preference] - - similarity_matrix = initialize_similarity(data, self_preference: self_preference) - - zero = Nx.tensor(0, type: Nx.type(similarity_matrix)) - availability_matrix = Nx.broadcast(zero, {n, n}) - responsibility_matrix = Nx.broadcast(zero, {n, n}) - - {availability_matrix, responsibility_matrix, similarity_matrix} - end - - defnp initialize_similarity(data, opts \\ []) do - n = Nx.axis_size(data, 0) - dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data) - - fill_in = - case opts[:self_preference] do - nil -> Nx.broadcast(Nx.median(dist), {n}) - self_preference -> Nx.broadcast(self_preference, {n}) - end - - Nx.put_diagonal(dist, fill_in) - end end