Skip to content

Commit

Permalink
Simplify self preferences
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim committed Oct 24, 2023
1 parent 8cf6d73 commit 8a583a1
Showing 1 changed file with 12 additions and 13 deletions.
25 changes: 12 additions & 13 deletions lib/scholar/cluster/affinity_propagation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,15 @@ defmodule Scholar.Cluster.AffinityPropagation do
"""
],
self_preference: [
type: {:or, [:float, :boolean, :integer]},
doc: "Self preference."
type: :float,
doc: """
Preferences for each point - points with larger values of preferences
are more likely to be chosen as exemplars. The number of clusters is
influenced by this option. If the preferences are not passed as arguments,
they will be set to the median of the input similarities (resulting in a
moderate number of clusters). For a smaller amount of clusters, this can
be set to the minimum value of the similarities.
"""
],
key: [
type: {:custom, Scholar.Options, :key, []},
Expand Down Expand Up @@ -117,7 +124,6 @@ defmodule Scholar.Cluster.AffinityPropagation do
"""
deftransform fit(data, opts \\ []) do
opts = NimbleOptions.validate!(opts, @opts_schema)
opts = Keyword.update(opts, :self_preference, false, fn x -> x end)
key = Keyword.get_lazy(opts, :key, fn -> Nx.Random.key(System.system_time()) end)
fit_n(data, key, NimbleOptions.validate!(opts, @opts_schema))
end
Expand Down Expand Up @@ -337,19 +343,12 @@ defmodule Scholar.Cluster.AffinityPropagation do

defnp initialize_similarities(data, opts \\ []) do
n = Nx.axis_size(data, 0)
self_preference = opts[:self_preference]

dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data)

fill_in =
cond do
self_preference == false ->
Nx.broadcast(Nx.median(dist), {n})

true ->
if Nx.size(self_preference) == 1,
do: Nx.broadcast(self_preference, {n}),
else: self_preference
case opts[:self_preference] do
nil -> Nx.broadcast(Nx.median(dist), {n})
self_preference -> Nx.broadcast(self_preference, {n})
end

s_modified = dist |> Nx.put_diagonal(fill_in)
Expand Down

0 comments on commit 8a583a1

Please sign in to comment.