Skip to content

Commit

Permalink
emcee fix
Browse files Browse the repository at this point in the history
  • Loading branch information
montyvesselinov committed Aug 5, 2024
1 parent 987eb4c commit f458c30
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions src/MadsMonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ function loadecmeeresults(madsdata::AbstractDict, filename::AbstractString)
return nothing, nothing, nothing
end
chain, llhoods, params, obs = JLD2.load(filename, "chain", "llhoods", "params", "obs")
@info("AffineInvariantMCMC results loaded: number of parameters = $(size(chain, 1)); number of realizations = $(size(chain, 2))")
@info("AffineInvariantMCMC results loaded: number of parameters = $(size(chain, 1)); number of observation = $(size(obs, 1)); number of realizations = $(size(chain, 2))")
if !Mads.jld2haskey(filename, "observations", "params", "obs"; quiet=false)
@warn("$(filename) does not contain AffineInvariantMCMC observation results!")
return chain, llhoods, nothing
end
observations, params, obs = JLD2.load(filename, "observations", "params", "obs")
flag_bad_data = false
if size(observations, 1) != length(Mads.getobskeys(madsdata))
if size(observations, 1) != length(Mads.getobskeys(madsdata))
@warn("Different number of observations (Mads $(length(Mads.getobskeys(madsdata))) vs Input $(size(observations, 1)))!")
flag_bad_data = true
end
Expand Down Expand Up @@ -81,7 +81,7 @@ function loadecmeeresults(madsdata::AbstractDict, filename::AbstractString)
end
end

function emceesampling(madsdata::AbstractDict; filename::AbstractString="", load::Bool=false, save::Bool=false, execute::Bool=true, numwalkers::Integer=10, sigma::Number=0.01, seed::Integer=-1, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, kw...)
function emceesampling(madsdata::AbstractDict; filename::AbstractString="", load::Bool=false, save::Bool=false, execute::Bool=true, numwalkers::Integer=10, nexecutions::Integer=100, burnin::Integer=numwalkers, thinning::Integer=10, sigma::Number=0.01, seed::Integer=-1, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, kw...)
if filename != ""
load = save = true
end
Expand Down Expand Up @@ -122,7 +122,7 @@ function emceesampling(madsdata::AbstractDict; filename::AbstractString="", load
p0[i, j] = pmin[i] + rand(Mads.rng, d) * (pmax[i] - pmin[i])
end
end
chain, llhoods, observations = emceesampling(madsdata, p0; filename=filename, load=load, save=save, execute=execute, numwalkers=numwalkers, seed=seed, rng=rng, kw...)
chain, llhoods, observations = emceesampling(madsdata, p0; filename=filename, load=load, save=save, execute=execute, numwalkers=numwalkers, nexecutions=nexecutions, burnin=burnin, thinning=thinning, seed=seed, rng=rng, kw...)
return chain, llhoods, observations
end
function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::AbstractString="", load::Bool=false, save::Bool=false, execute::Bool=true, numwalkers::Integer=10, nexecutions::Integer=100, burnin::Integer=numwalkers, thinning::Integer=10, seed::Integer=-1, weightfactor::Number=1.0, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, distributed_function::Bool=false)
Expand All @@ -144,12 +144,19 @@ function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::Abs
@warn("Empty!")
bad_data = true
else
if size(chain, 2) != div(numsamples_perwalker, thinning) * numwalkers
if size(chain, 2) != div(nexecutions, thinning)
@warn("Preexisting data does not match the number of walkers and steps!")
bad_data = true
elseif size(chain, 1) != no
poop
elseif size(chain, 1) != np
@warn("Preexisting data does not match the number of parameters!")
bad_data = true
elseif size(observations, 1) != no
@warn("Preexisting data does not match the number of observations!")
bad_data = true
elseif size(observations, 2) != size(chain, 2) != length(llhoods)
@warn("Preexisting data dimension do not match!")
bad_data = true
end
end
if bad_data
Expand Down Expand Up @@ -179,7 +186,7 @@ function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::Abs
numsamples = numsamples_perwalker * numwalkers
@info("AffineInvariantMCMC exploration stage (total number of executions $(numsamples), final chain size $(div(numsamples_perwalker, thinning) * numwalkers))...")
chain, llhoods = AffineInvariantMCMC.sample(arrayloglikelihood, numwalkers, burninchain[:, :, end], numsamples_perwalker, thinning; filename="", load=false, save=false, rng=Mads.rng,)
chain, llhoods = AffineInvariantMCMC.flattenmcmcarray(chain, llhoods)
chain, llhoods = AffineInvariantMCMC.flattenmcmcarray(chain, llhoods)
if save
madsinfo("Saving AffineInvariantMCMC results in $(filename) ...")
JLD2.save(filename, "chain", chain, "llhoods", llhoods, "params", Mads.getoptparamkeys(madsdata), "obs", Mads.getobskeys(madsdata))
Expand Down Expand Up @@ -358,4 +365,4 @@ Returns:
"""
function paramdict2array(dict::AbstractDict)
return permutedims(hcat(map(i->collect(dict[i]), collect(keys(dict)))...))
end
end

0 comments on commit f458c30

Please sign in to comment.