From 8cf6d734c20d201aa8321fb5666fec6ba6c4b51b Mon Sep 17 00:00:00 2001 From: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com> Date: Thu, 19 Oct 2023 16:21:34 +0200 Subject: [PATCH] Add convergence check to AffinityPropagation (#195) --- lib/scholar/cluster/affinity_propagation.ex | 64 ++++++++++++++------- 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/lib/scholar/cluster/affinity_propagation.ex b/lib/scholar/cluster/affinity_propagation.ex index b2523829..17b8a795 100644 --- a/lib/scholar/cluster/affinity_propagation.ex +++ b/lib/scholar/cluster/affinity_propagation.ex @@ -15,14 +15,16 @@ defmodule Scholar.Cluster.AffinityPropagation do :cluster_centers_indices, :affinity_matrix, :cluster_centers, - :num_clusters + :num_clusters, + :iterations ]} defstruct [ :labels, :cluster_centers_indices, :affinity_matrix, :cluster_centers, - :num_clusters + :num_clusters, + :iterations ] @opts_schema [ @@ -56,6 +58,14 @@ defmodule Scholar.Cluster.AffinityPropagation do doc: ~S""" If `true`, the learning loop is unrolled. """ + ], + converge_after: [ + type: :pos_integer, + default: 15, + doc: ~S""" + Number of iterations with no change in the number of estimated clusters + that stops the convergence. + """ ] ] @@ -101,7 +111,8 @@ defmodule Scholar.Cluster.AffinityPropagation do [1.0, -2.0, 5.0, 2.0] ] ), - num_clusters: Nx.tensor(2, type: :u64) + num_clusters: Nx.tensor(2, type: :u64), + iterations: Nx.tensor(18, type: :s64) } """ deftransform fit(data, opts \\ []) do @@ -116,7 +127,8 @@ defmodule Scholar.Cluster.AffinityPropagation do iterations = opts[:iterations] damping_factor = opts[:damping_factor] self_preference = opts[:self_preference] - data = to_float(data) + converge_after = opts[:converge_after] + num_samples = Nx.axis_size(data, 0) {initial_a, initial_r, s, affinity_matrix} = initialize_matrices(data, self_preference: self_preference) @@ -132,9 +144,12 @@ defmodule Scholar.Cluster.AffinityPropagation do range = Nx.iota({n}) - {{a, r}, _} = - while {{a = initial_a, r = initial_r}, {s, range, i = 0}}, - i < iterations do + e = Nx.broadcast(Nx.s64(0), {num_samples, converge_after}) + stop = Nx.u8(0) + + {{a, r, it}, _} = + while {{a = initial_a, r = initial_r, i = 0}, {s, range, stop, e}}, + i < iterations and not stop do temp = a + s indices = Nx.argmax(temp, axis: 1) y = Nx.reduce_max(temp, axes: [1]) @@ -160,7 +175,24 @@ defmodule Scholar.Cluster.AffinityPropagation do temp = temp * (1 - damping_factor) a = a * damping_factor - temp - {{a, r}, {s, range, i + 1}} + curr_e = Nx.take_diagonal(a) + Nx.take_diagonal(r) > 0 + curr_e_slice = Nx.reshape(curr_e, {:auto, 1}) + e = Nx.put_slice(e, [0, Nx.remainder(i, converge_after)], curr_e_slice) + k = Nx.sum(curr_e, axes: [0]) + + stop = + if i >= converge_after do + se = Nx.sum(e, axes: [1]) + unconverged = Nx.sum((se == 0) + (se == converge_after)) != num_samples + + if (not unconverged and k > 0) or i == iterations do + Nx.u8(1) + else + stop + end + end + + {{a, r, i + 1}, {s, range, stop, e}} end diagonals = Nx.take_diagonal(a) + Nx.take_diagonal(r) > 0 @@ -202,7 +234,8 @@ defmodule Scholar.Cluster.AffinityPropagation do cluster_centers_indices: cluster_centers_indices, cluster_centers: cluster_centers, labels: labels, - num_clusters: k + num_clusters: k, + iterations: it } end @@ -233,7 +266,8 @@ defmodule Scholar.Cluster.AffinityPropagation do [1.0, -2.0, 5.0, 2.0] ] ), - num_clusters: Nx.tensor(2, type: :u64) + num_clusters: Nx.tensor(2, type: :u64), + iterations: Nx.tensor(18, type: :s64) } """ def prune( @@ -281,16 +315,6 @@ defmodule Scholar.Cluster.AffinityPropagation do > """ defn predict(%__MODULE__{cluster_centers: cluster_centers} = _model, x) do - {num_clusters, num_features} = Nx.shape(cluster_centers) - {num_samples, _} = Nx.shape(x) - broadcast_shape = {num_samples, num_clusters, num_features} - - Scholar.Metrics.Distance.euclidean( - Nx.new_axis(x, 1) |> Nx.broadcast(broadcast_shape), - Nx.new_axis(cluster_centers, 0) |> Nx.broadcast(broadcast_shape), - axes: [-1] - ) - dist = Scholar.Metrics.Distance.pairwise_euclidean(x, cluster_centers) Nx.select(Nx.is_nan(dist), Nx.Constants.infinity(Nx.type(dist)), dist)