-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
256 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,11 @@ uuid = "e1fe09cc-5134-44c2-a941-50f4cd97986a" | |
authors = ["David Métivier <[email protected]> and contributors"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" | ||
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" | ||
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" | ||
|
||
[compat] | ||
julia = "1" | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,47 @@ | ||
# ExpectationMaximization | ||
|
||
This package provides a simple implementation of the Expectation Maximization algorithm used to fit mixture models. | ||
Thanks to [Julia](https://julialang.org/)'s amazing [multiple dispatch](https://www.youtube.com/watch?v=kc9HwsxE1OY) system and the [Distributions](https://juliastats.org/Distributions.jl/stable/) package, the code is very generic, i.e., mixtures of all common distributions should be supported. | ||
|
||
## Example | ||
|
||
```julia | ||
using Distributions | ||
using ExpectationMaximization | ||
``` | ||
|
||
### Model | ||
|
||
```julia | ||
N = 50000 | ||
θ₁ = 10 | ||
θ₂ = 5 | ||
α = 0.2 | ||
β = 0.3 | ||
# Mixture model: any classical distribution can be used as a component here | ||
mix_true = MixtureModel([Exponential(θ₁), Gamma(α, θ₂)], [β, 1 - β]) | ||
|
||
# Generate N samples from the mixture | ||
y = rand(mix_true, N) | ||
``` | ||
|
||
### Inference | ||
|
||
```julia | ||
# Initial guess | ||
mix_guess = MixtureModel([Exponential(1), Gamma(0.5, 1)], [0.5, 1 - 0.5]) | ||
|
||
# Fit the MLE with the EM algorithm | ||
mix_mle = fit_mle(mix_guess, y; display = :iter, tol = 1e-3, robust = false, infos = false) | ||
``` | ||
|
||
### Verify results | ||
|
||
```julia | ||
rtol = 5e-2 | ||
p = params(mix_mle)[1] # ((θ₁,), (α, θ₂)) | ||
isapprox(β, probs(mix_mle)[1]; rtol = rtol) | ||
isapprox(θ₁, p[1]...; rtol = rtol) | ||
isapprox(α, p[2][1]; rtol = rtol) | ||
isapprox(θ₂, p[2][2]; rtol = rtol) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
""" | ||
    fit_mle(mix::MixtureModel, y; display = :none, maxiter = 1000, tol = 1e-3, robust = false, infos = false) | ||
Use the Expectation Maximization (EM) algorithm to maximize the log-likelihood of (i.e., fit) the mixture to an i.i.d. sample `y`. | ||
The `mix` argument is a mixture that is used to initialize the EM algorithm. | ||
With `infos = true`, the tuple `(fitted mixture, history)` is returned instead of the fitted mixture alone. | ||
""" | ||
function fit_mle(mix::MixtureModel, y::AbstractVector; display = :none, maxiter = 1000, tol = 1e-3, robust = false, infos = false)
    # Vector sample: one observation per element of `y`.
    return _fit_mle_em(
        mix, y, length(y);
        display = display, maxiter = maxiter, tol = tol, robust = robust, infos = infos,
    )
end

"""
    fit_mle(mix::MixtureModel, y::AbstractMatrix; display = :none, maxiter = 1000, tol = 1e-3, robust = false, infos = false)

Same EM fit as the `AbstractVector` method, for a sample stored as a matrix:
the number of observations is `size(y, 2)`, i.e. one observation per column of `y`.
"""
function fit_mle(mix::MixtureModel, y::AbstractMatrix; display = :none, maxiter = 1000, tol = 1e-3, robust = false, infos = false)
    return _fit_mle_em(
        mix, y, size(y, 2);
        display = display, maxiter = maxiter, tol = tol, robust = robust, infos = infos,
    )
end

"""
    _fit_mle_em(mix, y, N; display, maxiter, tol, robust, infos)

EM loop shared by the vector and matrix `fit_mle(::MixtureModel, ...)` methods.
`N` is the number of observations contained in `y`.

# Keywords
- `display`: one of `:none`, `:iter` (print the log-likelihood at every iteration and the
  final status) or `:final` (print the final status only).
- `maxiter`: maximum number of EM iterations (must be `>= 0`).
- `tol`: the loop stops as soon as the absolute change of the total log-likelihood between
  two consecutive iterations is below `tol`.
- `robust`: if `true`, infinite entries of the log-likelihood matrix are replaced by finite
  values before the posteriors are computed.
- `infos`: if `true`, return `(fitted_mixture, history)` where `history` is a `Dict` with
  keys `"converged"`, `"iterations"` and `"logtots"`; otherwise return the mixture only.
"""
function _fit_mle_em(mix::MixtureModel, y, N::Integer; display, maxiter, tol, robust, infos)
    @argcheck display in [:none, :iter, :final]
    @argcheck maxiter >= 0

    K = ncomponents(mix)
    history = Dict("converged" => false, "iterations" => 0, "logtots" => zeros(0))

    # Buffers updated in place by every E-step.
    LL = zeros(N, K)   # log(α[k]) + log-density of component k for each observation
    γ = similar(LL)    # posterior probability of each component for each observation
    c = zeros(N)       # log-likelihood of each observation under the current mixture

    # Initial parameters (copied so that `mix` itself is never mutated).
    α = copy(probs(mix))
    dists = copy(components(mix))

    # Initial E-step with the parameters of the starting mixture.
    logtot = _estep!(LL, γ, c, α, dists, y, robust)
    (display == :iter) && println("Iteration 0: logtot = $logtot")

    for it = 1:maxiter
        # M-step: with γ in hand, update the weights and refit each component.
        α[:] = mean(γ, dims = 1)
        dists[:] = [fit_mle(dists[k], y, γ[:, k]) for k = 1:K]

        # E-step with the updated parameters.
        logtotp = _estep!(LL, γ, c, α, dists, y, robust)
        (display == :iter) && println("Iteration $it: logtot = $logtotp")

        push!(history["logtots"], logtotp)
        history["iterations"] += 1

        if abs(logtotp - logtot) < tol
            (display in [:iter, :final]) &&
                println("EM converged in $it iterations, logtot = $logtotp")
            history["converged"] = true
            break
        end

        logtot = logtotp
    end

    if !history["converged"] && display in [:iter, :final]
        println("EM has not converged after $(history["iterations"]) iterations, logtot = $logtot")
    end

    fitted = MixtureModel(dists, α)
    return infos ? (fitted, history) : fitted
end

"""
    _estep!(LL, γ, c, α, dists, y, robust)

E-step of the EM algorithm, writing into the preallocated buffers `LL`, `γ` and `c`:
column `k` of `LL` receives `log(α[k]) .+ logpdf(dists[k], y)`, `c` receives the row-wise
`logsumexp` of `LL`, and `γ` receives `exp.(LL .- c)`.
Return `sum(c)`, the total log-likelihood of the sample.
"""
function _estep!(LL, γ, c, α, dists, y, robust)
    # Evaluate the weighted log-likelihood of every observation for each component k.
    for k in axes(LL, 2)
        LL[:, k] = log(α[k]) .+ logpdf(dists[k], y)
    end
    # Replace infinite entries by finite values before they enter logsumexp/exp.
    robust && replace!(LL, -Inf => nextfloat(-Inf), Inf => log(prevfloat(Inf)))
    # Posterior of each category.
    c[:] = logsumexp(LL, dims = 2)
    γ[:, :] = exp.(LL .- c)
    return sum(c)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
#! What is the convention and/or the most efficient way to store these multidimensional arrays: should the samples of one distribution be stored in a column or in a row? | ||
#! Here we choose that each row corresponds to one distribution, and the number of columns is the number of realizations. | ||
#! This seems to be the convention. | ||
#TODO Only works for distributions with a known `fit_mle` method (cannot be a product_distribution of mixtures because of the `typeof`) | ||
""" | ||
Simply extend the `fit_mle` function to multivariate Product distributions. | ||
""" | ||
function fit_mle(g::Product, x::AbstractMatrix)
    # Row convention: row `s` of `x` is the sample used to refit the s-th marginal `g.v[s]`.
    vec_g = g.v
    S = size(x, 1)
    @argcheck S == length(vec_g)
    # Refit every marginal from its own row, dispatching on the marginal's type.
    fitted = map(zip(vec_g, eachrow(x))) do (marginal, y)
        fit_mle(typeof(marginal), y)
    end
    return product_distribution(fitted)
end
|
||
function fit_mle(g::Product, x::AbstractMatrix, γ::AbstractVector)
    # Weighted variant. Row convention: row `s` of `x` is the sample used to refit the
    # s-th marginal `g.v[s]`; the same vector `γ` is passed to every marginal fit.
    vec_g = g.v
    S = size(x, 1)
    @argcheck S == length(vec_g)
    fitted = map(zip(vec_g, eachrow(x))) do (marginal, y)
        fit_mle(typeof(marginal), y, γ)
    end
    return product_distribution(fitted)
end
|
||
""" | ||
Now `fit_mle(Bernoulli(0.2), x)` is accepted in addition to `fit_mle(Bernoulli, x)`; this allows compatibility with how `fit_mle(g::Product)` and `fit_mle(g::MixtureModel)` are written. | ||
""" | ||
# Instance form for discrete univariate distributions: only the type of `g` is used,
# every remaining argument is forwarded unchanged to the type-based `fit_mle` method.
fit_mle(g::Distribution{Univariate,Discrete}, args...) = fit_mle(typeof(g), args...)
|
||
# Instance form for continuous univariate distributions: only the type of `g` is used,
# every remaining argument is forwarded unchanged to the type-based `fit_mle` method.
fit_mle(g::Distribution{Univariate,Continuous}, args...) = fit_mle(typeof(g), args...)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,22 @@ | ||
using ExpectationMaximization | ||
using Distributions | ||
using Test | ||
|
||
@testset "ExpectationMaximization.jl" begin
    # Fit a two-component Exponential/Gamma mixture with EM and check that the weight and
    # the component parameters of the generating mixture are recovered within `rtol`.
    # NOTE(review): the sample is drawn without seeding an RNG, so the data differ from
    # run to run — consider seeding for reproducibility.
    rtol = 5e-2
    N = 50000
    θ₁, θ₂ = 10, 5
    α, β = 0.2, 0.3

    mix_true = MixtureModel([Exponential(θ₁), Gamma(α, θ₂)], [β, 1 - β])
    mix_guess = MixtureModel([Exponential(1), Gamma(0.5, 1)], [0.5, 1 - 0.5])

    y = rand(mix_true, N)
    mix_mle = fit_mle(mix_guess, y; display = :iter, tol = 1e-3, robust = false, infos = false)

    # First entry of `params` holds the per-component parameter tuples.
    component_params = params(mix_mle)[1]
    exponential_params = component_params[1]
    gamma_params = component_params[2]

    @test isapprox(β, probs(mix_mle)[1]; rtol = rtol)
    @test isapprox(θ₁, exponential_params...; rtol = rtol)
    @test isapprox(α, gamma_params[1]; rtol = rtol)
    @test isapprox(θ₂, gamma_params[2]; rtol = rtol)
end