From 8a583a19582ecf122e0db280f5d9f1fc91acaba1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Valim?= Date: Tue, 24 Oct 2023 10:16:45 +0200 Subject: [PATCH] Simplify self preferences --- lib/scholar/cluster/affinity_propagation.ex | 25 ++++++++++----------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/lib/scholar/cluster/affinity_propagation.ex b/lib/scholar/cluster/affinity_propagation.ex index 17b8a795..b4ad46e0 100644 --- a/lib/scholar/cluster/affinity_propagation.ex +++ b/lib/scholar/cluster/affinity_propagation.ex @@ -42,8 +42,15 @@ defmodule Scholar.Cluster.AffinityPropagation do """ ], self_preference: [ - type: {:or, [:float, :boolean, :integer]}, - doc: "Self preference." + type: :float, + doc: """ + Preferences for each point - points with larger values of preferences + are more likely to be chosen as exemplars. The number of clusters is + influenced by this option. If the preferences are not passed as arguments, + they will be set to the median of the input similarities (resulting in a + moderate number of clusters). For a smaller amount of clusters, this can + be set to the minimum value of the similarities. + """ ], key: [ type: {:custom, Scholar.Options, :key, []}, @@ -117,7 +124,6 @@ defmodule Scholar.Cluster.AffinityPropagation do """ deftransform fit(data, opts \\ []) do opts = NimbleOptions.validate!(opts, @opts_schema) - opts = Keyword.update(opts, :self_preference, false, fn x -> x end) key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end) fit_n(data, key, NimbleOptions.validate!(opts, @opts_schema)) end @@ -337,19 +343,12 @@ defmodule Scholar.Cluster.AffinityPropagation do defnp initialize_similarities(data, opts \\ []) do n = Nx.axis_size(data, 0) - self_preference = opts[:self_preference] - dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data) fill_in = - cond do - self_preference == false -> - Nx.broadcast(Nx.median(dist), {n}) - - true -> - if Nx.size(self_preference) == 1, - do: Nx.broadcast(self_preference, {n}), - else: self_preference + case opts[:self_preference] do + nil -> Nx.broadcast(Nx.median(dist), {n}) + self_preference -> Nx.broadcast(self_preference, {n}) end s_modified = dist |> Nx.put_diagonal(fill_in)