From f458c3094bae33819ace8210cb321f5835db8d15 Mon Sep 17 00:00:00 2001 From: "Velimir (monty) Vesselinov" Date: Sun, 4 Aug 2024 20:46:50 -0600 Subject: [PATCH] emcee fix --- src/MadsMonteCarlo.jl | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/MadsMonteCarlo.jl b/src/MadsMonteCarlo.jl index 0c9e2289..2f444eb3 100644 --- a/src/MadsMonteCarlo.jl +++ b/src/MadsMonteCarlo.jl @@ -32,14 +32,14 @@ function loadecmeeresults(madsdata::AbstractDict, filename::AbstractString) return nothing, nothing, nothing end chain, llhoods, params, obs = JLD2.load(filename, "chain", "llhoods", "params", "obs") - @info("AffineInvariantMCMC results loaded: number of parameters = $(size(chain, 1)); number of realizations = $(size(chain, 2))") + @info("AffineInvariantMCMC results loaded: number of parameters = $(size(chain, 1)); number of observation = $(size(obs, 1)); number of realizations = $(size(chain, 2))") if !Mads.jld2haskey(filename, "observations", "params", "obs"; quiet=false) @warn("$(filename) does not contain AffineInvariantMCMC observation results!") return chain, llhoods, nothing end observations, params, obs = JLD2.load(filename, "observations", "params", "obs") flag_bad_data = false - if size(observations, 1) != length(Mads.getobskeys(madsdata)) + if size(observations, 1) != length(Mads.getobskeys(madsdata)) @warn("Different number of observations (Mads $(length(Mads.getobskeys(madsdata))) vs Input $(size(observations, 1)))!") flag_bad_data = true end @@ -81,7 +81,7 @@ function loadecmeeresults(madsdata::AbstractDict, filename::AbstractString) end end -function emceesampling(madsdata::AbstractDict; filename::AbstractString="", load::Bool=false, save::Bool=false, execute::Bool=true, numwalkers::Integer=10, sigma::Number=0.01, seed::Integer=-1, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, kw...) +function emceesampling(madsdata::AbstractDict; filename::AbstractString="", load::Bool=false, save::Bool=false, execute::Bool=true, numwalkers::Integer=10, nexecutions::Integer=100, burnin::Integer=numwalkers, thinning::Integer=10, sigma::Number=0.01, seed::Integer=-1, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, kw...) if filename != "" load = save = true end @@ -122,7 +122,7 @@ function emceesampling(madsdata::AbstractDict; filename::AbstractString="", load p0[i, j] = pmin[i] + rand(Mads.rng, d) * (pmax[i] - pmin[i]) end end - chain, llhoods, observations = emceesampling(madsdata, p0; filename=filename, load=load, save=save, execute=execute, numwalkers=numwalkers, seed=seed, rng=rng, kw...) + chain, llhoods, observations = emceesampling(madsdata, p0; filename=filename, load=load, save=save, execute=execute, numwalkers=numwalkers, nexecutions=nexecutions, burnin=burnin, thinning=thinning, seed=seed, rng=rng, kw...) return chain, llhoods, observations end function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::AbstractString="", load::Bool=false, save::Bool=false, execute::Bool=true, numwalkers::Integer=10, nexecutions::Integer=100, burnin::Integer=numwalkers, thinning::Integer=10, seed::Integer=-1, weightfactor::Number=1.0, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, distributed_function::Bool=false) @@ -144,12 +144,19 @@ function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::Abs @warn("Empty!") bad_data = true else - if size(chain, 2) != div(numsamples_perwalker, thinning) * numwalkers + if size(chain, 2) != div(nexecutions, thinning) @warn("Preexisting data does not match the number of walkers and steps!") bad_data = true - elseif size(chain, 1) != no + poop + elseif size(chain, 1) != np @warn("Preexisting data does not match the number of parameters!") bad_data = true + elseif size(observations, 1) != no + @warn("Preexisting data does not match the number of observations!") + bad_data = true + elseif size(observations, 2) != size(chain, 2) != length(llhoods) + @warn("Preexisting data dimension do not match!") + bad_data = true end end if bad_data @@ -179,7 +186,7 @@ function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::Abs numsamples = numsamples_perwalker * numwalkers @info("AffineInvariantMCMC exploration stage (total number of executions $(numsamples), final chain size $(div(numsamples_perwalker, thinning) * numwalkers))...") chain, llhoods = AffineInvariantMCMC.sample(arrayloglikelihood, numwalkers, burninchain[:, :, end], numsamples_perwalker, thinning; filename="", load=false, save=false, rng=Mads.rng,) - chain, llhoods = AffineInvariantMCMC.flattenmcmcarray(chain, llhoods) + chain, llhoods = AffineInvariantMCMC.flattenmcmcarray(chain, llhoods) if save madsinfo("Saving AffineInvariantMCMC results in $(filename) ...") JLD2.save(filename, "chain", chain, "llhoods", llhoods, "params", Mads.getoptparamkeys(madsdata), "obs", Mads.getobskeys(madsdata)) @@ -358,4 +365,4 @@ Returns: """ function paramdict2array(dict::AbstractDict) return permutedims(hcat(map(i->collect(dict[i]), collect(keys(dict)))...)) -end +end \ No newline at end of file