Skip to content

Commit

Permalink
Add suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
msluszniak committed Oct 19, 2023
1 parent ff46a42 commit 9b74d9a
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions lib/scholar/cluster/affinity_propagation.ex
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ defmodule Scholar.Cluster.AffinityPropagation do
If `true`, the learning loop is unrolled.
"""
],
convergence_iter: [
converge_after: [
type: :pos_integer,
default: 15,
doc: ~S"""
Expand Down Expand Up @@ -127,6 +127,7 @@ defmodule Scholar.Cluster.AffinityPropagation do
iterations = opts[:iterations]
damping_factor = opts[:damping_factor]
self_preference = opts[:self_preference]
converge_after = opts[:converge_after]
num_samples = Nx.axis_size(data, 0)

{initial_a, initial_r, s, affinity_matrix} =
Expand All @@ -143,9 +144,7 @@ defmodule Scholar.Cluster.AffinityPropagation do

range = Nx.iota({n})

Nx.broadcast(Nx.tensor(0, type: Nx.type(data)), {n, opts[:convergence_iter]})

e = Nx.broadcast(Nx.s64(0), {num_samples, opts[:convergence_iter]})
e = Nx.broadcast(Nx.s64(0), {num_samples, converge_after})
stop = Nx.u8(0)

{{a, r, it}, _} =
Expand Down Expand Up @@ -178,12 +177,12 @@ defmodule Scholar.Cluster.AffinityPropagation do

curr_e = Nx.take_diagonal(a) + Nx.take_diagonal(r) > 0
curr_e_slice = Nx.reshape(curr_e, {:auto, 1})
e = Nx.put_slice(e, [0, Nx.remainder(i, opts[:convergence_iter])], curr_e_slice)
e = Nx.put_slice(e, [0, Nx.remainder(i, converge_after)], curr_e_slice)
k = Nx.sum(curr_e, axes: [0])

stop = if i >= opts[:convergence_iter] do
stop = if i >= converge_after do
se = Nx.sum(e, axes: [1])
unconverged = Nx.sum((se == 0) + (se == opts[:convergence_iter])) != num_samples
unconverged = Nx.sum((se == 0) + (se == converge_after)) != num_samples
if (not unconverged and k > 0) or i == iterations do
Nx.u8(1)
else
Expand Down

0 comments on commit 9b74d9a

Please sign in to comment.