diff --git a/kmc2.pyx b/kmc2.pyx index c361bb0..e1758d0 100644 --- a/kmc2.pyx +++ b/kmc2.pyx @@ -107,15 +107,10 @@ def kmc2(X, k, chain_length=200, afkmc2=True, random_state=None, weights=None): # Markov chain for j in range(q_cand.shape[0]): cand_prob = p_cand[j]/q_cand[j] - if j == 0 or curr_prob == 0.0: - # Init new chain + if j == 0 or curr_prob == 0.0 or cand_prob/curr_prob > rand_a[j]: + # Init new chain Metropolis-Hastings step curr_ind = j curr_prob = cand_prob - else: - # Metropolis-Hastings step - if cand_prob/curr_prob > rand_a[j]: - curr_ind = j - curr_prob = cand_prob rel_row = X[cand_ind[curr_ind], :] centers[i+1, :] = rel_row.todense().flatten() if sparse else rel_row return centers