Skip to content

Commit

Permalink
add Laplace weigths fit_mle and "docs" and test
Browse files Browse the repository at this point in the history
  • Loading branch information
dmetivie committed Jan 23, 2023
1 parent abc3325 commit 2c6f6b1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/that_should_be_in_Distributions.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#! What is the convention and or the most efficient way to store these multidim arrays: samples from column/row correspond to one distributions?
#! Here we choose row correspond to one distributions and number of row is the number of realization
#! It seems to be the convention
#TODO Only work for distributions with known fit_mle (cannot be product_distribution of mixture because of the typeof)
"""
Simply extend the `fit_mle` function to multivariate Product distributions.
fit_mle(g::Product, x::AbstractMatrix)
fit_mle(g::Product, x::AbstractMatrix, γ::AbstractVector)
The `fit_mle` for multivariate Product distributions `g` is the `product_distribution` of `fit_mle` of each components of `g`.
"""
function fit_mle(g::Product, x::AbstractMatrix)
S = size(x, 1) # Distributions convention
Expand Down Expand Up @@ -37,11 +36,21 @@ function fit_mle(g::Distribution{Multivariate,S}, args...) where {S}
fit_mle(typeof(g), args...)
end

# * `fit_mle` (weighted or not) of some distribution

fit_mle(::Type{<:Dirac}, x::AbstractArray{T}) where {T<:Real} = length(unique(x)) == 1 ? Dirac(first(x)) : Dirac(NaN)
function fit_mle(::Type{<:Dirac}, x::AbstractArray{T}, w::AbstractArray{Float64}) where {T<:Real}
n = length(x)
if n != length(w)
throw(DimensionMismatch("Inconsistent array lengths."))
end
return length(unique(x[findall(!iszero, w)])) == 1 ? Dirac(first(x)) : Dirac(NaN)
end

function fit_mle(::Type{<:Laplace}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
xc = similar(x)
copyto!(xc, x)
m = median(xc, weights(w))
xc .= abs.(x .- m)
return Laplace(m, mean(xc, weights(w)))
end
42 changes: 42 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,48 @@ end
mix_mle = fit_mle(mix_guess, y; display=:iter, tol=1e-3, robust=false, infos=false)
y_guess = rand(seed, mix_mle, N)

@test probs(mix_mle) [β, 1 - β] rtol = rtol
p = params(mix_mle)[1]
@test p[1][2] [α, 1 - α] rtol = rtol
@test θ₁ p[1][1][1][1] rtol = rtol
@test σ₁ p[1][1][1][2] rtol = rtol
@test θ₂ p[1][1][2][1] rtol = rtol
@test σ₂ p[1][1][2][2] rtol = rtol
@test θ₀ p[2][1] rtol = rtol
@test σ₀ p[2][2] rtol = rtol
end

@testset "Univariate continuous Mixture of (mixture + Normal)" begin
N = 50_000
seed = MersenneTwister(0)
θ₁ = -2
θ₂ = 2
σ₁ = 1
σ₂ = 1.5
θ₀ = 0.1
σ₀ = 0.2

α = 1 / 2
β = 0.3

rtol = 3e-2 # 3%
d1 = MixtureModel([Normal(θ₁, σ₁), Laplace(θ₂, σ₂)], [α, 1 - α])
d2 = Normal(θ₀, σ₀)
mix_true = MixtureModel([d1, d2], [β, 1 - β])
y = rand(seed, mix_true, N)

# We choose initial guess very close to the true solution just to show the EM algorithm convergence.
# This particular choice of mixture of mixture Gaussian with another Gaussian is non identifiable hence we execpt other solution far away from the true solution
d1_guess = MixtureModel([Normal(θ₁ - 4, σ₁ + 2), Laplace(θ₂ + 2, σ₂ - 1)], [α + 0.1, 1 - α - 0.1])
d2_guess = Normal(θ₀ + 2, 10σ₀)

mix_guess = MixtureModel([d1_guess, d2_guess], [β + 0.1, 1 - β - 0.1])
mix_mle = fit_mle(mix_guess, y; display=:iter, tol=1e-3, robust=false, infos=false)
# without print
# 1.368 s (17002715 allocations: 1.48 GiB)
# 1.485 s (17853393 allocations: 1.61 GiB)
y_guess = rand(seed, mix_mle, N)

@test probs(mix_mle) [β, 1 - β] rtol = rtol
p = params(mix_mle)[1]
@test p[1][2] [α, 1 - α] rtol = rtol
Expand Down

2 comments on commit 2c6f6b1

@dmetivie
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/76031

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.6 -m "<description of version>" 2c6f6b1f5f88727c189ca93d4aaba4003e2019a9
git push origin v0.1.6

Please sign in to comment.