Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
montyvesselinov committed Mar 21, 2024
1 parent 7d7496c commit c53771f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src-old/MadsEmcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Example:
Mads.emcee(llhood, numwalkers=10, numsamples_perwalker=100, thinning=1)
```
"""
function sample(llhood::Function, numwalkers::Integer, x0::AbstractMatrix, numsamples_perwalker::Integer, thinning::Integer, a::Number=2.; filename::AbstractString="", load::Bool=true, save::Bool=true, rng::Random.AbstractRNG=Random.GLOBAL_RNG)
function sample(llhood::Function, numwalkers::Integer, x0::AbstractMatrix, numsamples_perwalker::Integer, thinning::Integer=numwalkers, a::Number=2.; filename::AbstractString="", load::Bool=true, save::Bool=true, rng::Random.AbstractRNG=Random.GLOBAL_RNG)
if filename != "" && isfile(filename) && load
chain, llhoodvals = JLD2.load(filename, "chain", "llhoods")
end
Expand Down
21 changes: 14 additions & 7 deletions src/MadsMonteCarlo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ function loadecmeeresults(madsdata::AbstractDict, filename::AbstractString)
return chain, llhoods, observations
end
else
@warn("AffineInvariantMCMC results file $(filename) is missing!")
return nothing, nothing, nothing
end
end
Expand Down Expand Up @@ -116,17 +117,20 @@ function emceesampling(madsdata::AbstractDict; filename::AbstractString="", load
return chain, llhoods, observations
end

function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::AbstractString="", load::Bool=true, save::Bool=true, execute::Bool=true, numwalkers::Integer=10, nsteps::Integer=100, burnin::Integer=10, thinning::Integer=1, seed::Integer=-1, weightfactor::Number=1.0, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, distributed_function::Bool=false)
function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::AbstractString="", load::Bool=true, save::Bool=true, execute::Bool=true, numwalkers::Integer=10, nexecutions::Integer=100, burnin::Integer=10, thinning::Integer=numwalkers, seed::Integer=-1, weightfactor::Number=1.0, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing, distributed_function::Bool=false)
numsamples_perwalker = div(nexecutions, numwalkers)
if load && filename != ""
bad_data = false
@info("AffineInvariantMCMC preexisting data loading $(filename) ...")
chain, llhoods, observations = loadecmeeresults(madsdata, filename)
if isnothing(chain)
@warn("Empty!")
bad_data = true
else
if size(chain, 2) != numwalkers * nsteps
if size(chain, 2) != div(numsamples_perwalker, thinning) * numwalkers
@warn("Preexisting data does not match the number of walkers and steps!")
bad_data = true
elseif size(chain,12) != length(Mads.getoptparamkeys(madsdata))
elseif size(chain, 1) != length(Mads.getoptparamkeys(madsdata))
@warn("Preexisting data does not match the number of parameters!")
bad_data = true
end
Expand All @@ -151,10 +155,13 @@ function emceesampling(madsdata::AbstractDict, p0::AbstractMatrix; filename::Abs
@Distributed.everywhere arrayloglikelihood_distributed = Mads.makearrayloglikelihood($madsdata, $madsloglikelihood)
arrayloglikelihood = (x)->Core.eval(Main, :arrayloglikelihood_distributed)(x)
end
@info("AffineInvariantMCMC burning stage ...")
burninchain, _ = AffineInvariantMCMC.sample(arrayloglikelihood, numwalkers, p0, div(burnin, numwalkers), 1; filename="", rng=Mads.rng, save=false)
@info("AffineInvariantMCMC exploration stage ...")
chain, llhoods = AffineInvariantMCMC.sample(arrayloglikelihood, numwalkers, burninchain[:, :, end], div(nsteps, numwalkers), thinning; filename="", rng=Mads.rng, save=false)
numsamples_perwalker_burnin = div(burnin, numwalkers)
numsamples = numsamples_perwalker_burnin * numwalkers
@info("AffineInvariantMCMC burning stage (total number of executions $(numsamples), final burning chain size $(numsamples_perwalker_burnin * numwalkers))...")
burninchain, _ = AffineInvariantMCMC.sample(arrayloglikelihood, numwalkers, p0, numsamples_perwalker_burnin, 1; filename="", load=false, save=false, rng=Mads.rng)
numsamples = numsamples_perwalker * numwalkers
@info("AffineInvariantMCMC exploration stage (total number of executions $(numsamples), final chain size $(div(numsamples_perwalker, thinning) * numwalkers))...")
chain, llhoods = AffineInvariantMCMC.sample(arrayloglikelihood, numwalkers, burninchain[:, :, end], numsamples_perwalker, thinning; filename="", load=false, save=false, rng=Mads.rng,)
chain, llhoods = AffineInvariantMCMC.flattenmcmcarray(chain, llhoods)
if save && filename != ""
JLD2.save(filename, "chain", chain, "llhoods", llhoods, "params", Mads.getoptparamkeys(madsdata), "obs", Mads.getobskeys(madsdata))
Expand Down

0 comments on commit c53771f

Please sign in to comment.