Skip to content

Commit

Permalink
hoist the robust check
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Nov 21, 2023
1 parent 137f03a commit e6988eb
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions src/fit_em.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,18 @@ function E_step!(
) where {T<:AbstractFloat}
# evaluate likelihood for each type k
for k in eachindex(dists)
logα = log(α[k])
robust && !isfinite(logα) && continue
distk = dists[k]
for n in eachindex(y)
logp = logpdf(distk, y[n])
if !robust || isfinite(logp)
logα, distk = log(α[k]), dists[k]
if robust
isfinite(logα) || continue
for n in eachindex(y)
logp = logpdf(distk, y[n])
isfinite(logp) || continue
LL[n, k] = logα + logp
end
else
for n in eachindex(y)
LL[n, k] = logα + logpdf(distk, y[n])
end
end
end
robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
Expand All @@ -151,14 +155,18 @@ function E_step!(
)
# evaluate likelihood for each type k
for k in eachindex(dists)
logα = log(α[k])
robust && !isfinite(logα) && continue
distk = dists[k]
for n in axes(y, 2)
logp = logpdf(distk, y[:, n])
if !robust || isfinite(logp)
logα, distk = log(α[k]), dists[k]
if robust
isfinite(logα) || continue
for n in axes(y, 2)
logp = logpdf(distk, y[:, n])
isfinite(logp) || continue
LL[n, k] = logα + logp
end
else
for n in axes(y, 2)
LL[n, k] = logα + logpdf(distk, y[:, n])
end
end
end
# get posterior of each category
Expand Down

0 comments on commit e6988eb

Please sign in to comment.