From e6988ebc9e96e6fa8b9810f0b94a0da3e2000f35 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Tue, 21 Nov 2023 05:13:41 -0600 Subject: [PATCH] hoist the `robust` check --- src/fit_em.jl | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/src/fit_em.jl b/src/fit_em.jl index bf89f35..c20dbb0 100644 --- a/src/fit_em.jl +++ b/src/fit_em.jl @@ -124,14 +124,18 @@ function E_step!( ) where {T<:AbstractFloat} # evaluate likelihood for each type k for k in eachindex(dists) - logα = log(α[k]) - robust && !isfinite(logα) && continue - distk = dists[k] - for n in eachindex(y) - logp = logpdf(distk, y[n]) - if !robust || isfinite(logp) + logα, distk = log(α[k]), dists[k] + if robust + isfinite(logα) || continue + for n in eachindex(y) + logp = logpdf(distk, y[n]) + isfinite(logp) || continue LL[n, k] = logα + logp end + else + for n in eachindex(y) + LL[n, k] = logα + logpdf(distk, y[n]) + end end end robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf))) @@ -151,14 +155,18 @@ function E_step!( ) # evaluate likelihood for each type k for k in eachindex(dists) - logα = log(α[k]) - robust && !isfinite(logα) && continue - distk = dists[k] - for n in axes(y, 2) - logp = logpdf(distk, y[:, n]) - if !robust || isfinite(logp) + logα, distk = log(α[k]), dists[k] + if robust + isfinite(logα) || continue + for n in axes(y, 2) + logp = logpdf(distk, y[:, n]) + isfinite(logp) || continue LL[n, k] = logα + logp end + else + for n in axes(y, 2) + LL[n, k] = logα + logpdf(distk, y[:, n]) + end end end # get posterior of each category