Allow AP preference to be any Nx computation, default to reduce_min (e…
josevalim authored Oct 24, 2023
1 parent 4b3e824 commit 8fa66fc
Showing 2 changed files with 34 additions and 31 deletions.
59 changes: 31 additions & 28 deletions lib/scholar/cluster/affinity_propagation.ex
@@ -40,14 +40,16 @@ defmodule Scholar.Cluster.AffinityPropagation do
"""
],
preference: [
type: :float,
type: {:or, [:float, :atom]},
default: :reduce_min,
doc: """
Preferences for each point - points with larger values of preferences
are more likely to be chosen as exemplars. The number of clusters is
influenced by this option. If the preferences are not passed as arguments,
they will be set to the median of the input similarities (resulting in a
moderate number of clusters). For a smaller amount of clusters, this can
be set to the minimum value of the similarities.
How to compute the preferences for each point - points with larger values
of preferences are more likely to be chosen as exemplars. The number of clusters is
influenced by this option.
The preference is either an atom, which is an `Nx` reduction function to
apply on the input similarities (such as `:reduce_min`, `:median`, `:mean`,
etc.), or a float.
"""
],
key: [
@@ -94,21 +96,21 @@ defmodule Scholar.Cluster.AffinityPropagation do
## Examples
iex> key = Nx.Random.key(42)
iex> x = Nx.tensor([[12,5,78,2], [1,-5,7,32], [-1,3,6,1], [1,-2,5,2]])
iex> x = Nx.tensor([[12,5,78,2], [9,3,81,-2], [-1,3,6,1], [1,-2,5,2]])
iex> Scholar.Cluster.AffinityPropagation.fit(x, key: key)
%Scholar.Cluster.AffinityPropagation{
labels: Nx.tensor([0, 3, 3, 3]),
cluster_centers_indices: Nx.tensor([0, -1, -1, 3]),
labels: Nx.tensor([0, 0, 2, 2]),
cluster_centers_indices: Nx.tensor([0, -1, 2, -1]),
cluster_centers: Nx.tensor(
[
[12.0, 5.0, 78.0, 2.0],
[:infinity, :infinity, :infinity, :infinity],
[:infinity, :infinity, :infinity, :infinity],
[1.0, -2.0, 5.0, 2.0]
[-1.0, 3.0, 6.0, 1.0],
[:infinity, :infinity, :infinity, :infinity]
]
),
num_clusters: Nx.tensor(2, type: :u64),
iterations: Nx.tensor(18, type: :s64)
iterations: Nx.tensor(22, type: :s64)
}
"""
deftransform fit(data, opts \\ []) do
@@ -233,14 +235,15 @@ defmodule Scholar.Cluster.AffinityPropagation do
defnp initialize_similarity(data, opts \\ []) do
n = Nx.axis_size(data, 0)
dist = -Scholar.Metrics.Distance.pairwise_squared_euclidean(data)
preference = initialize_preference(dist, opts[:preference])
Nx.put_diagonal(dist, Nx.broadcast(preference, {n}))
end

fill_in =
case opts[:preference] do
nil -> Nx.broadcast(Nx.median(dist), {n})
preference -> Nx.broadcast(preference, {n})
end

Nx.put_diagonal(dist, fill_in)
deftransformp initialize_preference(dist, preference) do
cond do
is_atom(preference) -> apply(Nx, preference, [dist])
is_float(preference) -> preference
end
end
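# Illustration (not part of the commit): the atom branch above turns the
# option value into a call on the similarity matrix via apply/3, so
# apply(Nx, :reduce_min, [dist]) is the same call as Nx.reduce_min(dist),
# and any single-tensor Nx reduction such as :median or :mean works too;
# a float is returned unchanged and broadcast onto the diagonal by the caller.
#
#   dist = Nx.tensor([[0.0, -5.0], [-5.0, 0.0]])
#   apply(Nx, :reduce_min, [dist])  # same result as Nx.reduce_min(dist) => -5.0
#   apply(Nx, :median, [dist])      # same result as Nx.median(dist)     => -2.5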

@doc """
@@ -251,20 +254,20 @@ defmodule Scholar.Cluster.AffinityPropagation do
## Examples
iex> key = Nx.Random.key(42)
iex> x = Nx.tensor([[12,5,78,2], [1,-5,7,32], [-1,3,6,1], [1,-2,5,2]])
iex> x = Nx.tensor([[12,5,78,2], [9,3,81,-2], [-1,3,6,1], [1,-2,5,2]])
iex> model = Scholar.Cluster.AffinityPropagation.fit(x, key: key)
iex> Scholar.Cluster.AffinityPropagation.prune(model)
%Scholar.Cluster.AffinityPropagation{
labels: Nx.tensor([0, 1, 1, 1]),
cluster_centers_indices: Nx.tensor([0, 3]),
labels: Nx.tensor([0, 0, 1, 1]),
cluster_centers_indices: Nx.tensor([0, 2]),
cluster_centers: Nx.tensor(
[
[12.0, 5.0, 78.0, 2.0],
[1.0, -2.0, 5.0, 2.0]
[-1.0, 3.0, 6.0, 1.0]
]
),
num_clusters: Nx.tensor(2, type: :u64),
iterations: Nx.tensor(18, type: :s64)
iterations: Nx.tensor(22, type: :s64)
}
"""
def prune(
@@ -302,13 +305,13 @@ defmodule Scholar.Cluster.AffinityPropagation do
## Examples
iex> key = Nx.Random.key(42)
iex> x = Nx.tensor([[12,5,78,2], [1,5,7,32], [1,3,6,1], [1,2,5,2]])
iex> x = Nx.tensor([[12,5,78,2], [9,3,81,-2], [-1,3,6,1], [1,-2,5,2]])
iex> model = Scholar.Cluster.AffinityPropagation.fit(x, key: key)
iex> model = Scholar.Cluster.AffinityPropagation.prune(model)
iex> Scholar.Cluster.AffinityPropagation.predict(model, Nx.tensor([[1,6,2,6], [8,3,8,2]]))
iex> Scholar.Cluster.AffinityPropagation.predict(model, Nx.tensor([[10,3,50,6], [8,3,8,2]]))
#Nx.Tensor<
s64[2]
[1, 1]
[0, 1]
>
"""
defn predict(%__MODULE__{cluster_centers: cluster_centers} = _model, x) do
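Taken together, these changes let callers choose how the preference is derived at fit time. A minimal usage sketch (illustrative values, assuming Scholar at this commit; not part of the diff):

key = Nx.Random.key(42)
x = Nx.tensor([[12, 5, 78, 2], [9, 3, 81, -2], [-1, 3, 6, 1], [1, -2, 5, 2]])

# New default: preference is Nx.reduce_min/1 of the similarities,
# which favours fewer clusters.
model = Scholar.Cluster.AffinityPropagation.fit(x, key: key)

# Previous behaviour: median of the similarities, for a moderate number
# of clusters (this is what the updated tests now pass explicitly).
model = Scholar.Cluster.AffinityPropagation.fit(x, key: key, preference: :median)

# An explicit float is still accepted and used for every point.
model = Scholar.Cluster.AffinityPropagation.fit(x, key: key, preference: -500.0)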
6 changes: 3 additions & 3 deletions test/scholar/cluster/affinity_propagation_test.exs
@@ -68,7 +68,7 @@ defmodule Scholar.Cluster.AffinityPropagationTest do
end

test "fit and compute_values" do
model = AffinityPropagation.fit(x(), key: key())
model = AffinityPropagation.fit(x(), key: key(), preference: :median)

model = AffinityPropagation.prune(model)

@@ -94,14 +94,14 @@ defmodule Scholar.Cluster.AffinityPropagationTest do
end

test "predict with pruning" do
model = AffinityPropagation.fit(x(), key: key())
model = AffinityPropagation.fit(x(), key: key(), preference: :median)
model = AffinityPropagation.prune(model)
preds = AffinityPropagation.predict(model, x_test())
assert preds == Nx.tensor([0, 2, 0, 5, 5, 5, 2, 2, 5, 2])
end

test "predict without pruning" do
model = AffinityPropagation.fit(x(), key: key())
model = AffinityPropagation.fit(x(), key: key(), preference: :median)
preds = AffinityPropagation.predict(model, x_test())
assert preds == Nx.tensor([2, 9, 2, 34, 34, 34, 9, 9, 34, 9])
end
