Skip to content

Commit

Permalink
sensivity parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
montyvesselinov committed Mar 14, 2024
1 parent 84e7d34 commit c051730
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
27 changes: 16 additions & 11 deletions src/MadsForward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,38 @@ function forward(madsdata::AbstractDict, paramarray::AbstractArray; parallel::Bo
restartdir = getrestartdir(madsdata_c)
if checkpointfrequency != 0 && restartdir != ""
@info("RobustPmap for parallel execution of forward runs with checkpoint frequency...")
if s[2] == np
rv = RobustPmap.crpmap(i->func_forward(vec(paramarray[i, :])), checkpointfrequency, joinpath(restartdir, checkpointfilename), 1:ncases)
else
if s[1] == np
rv = RobustPmap.crpmap(i->func_forward(vec(paramarray[:, i])), checkpointfrequency, joinpath(restartdir, checkpointfilename), 1:ncases)
else
rv = RobustPmap.crpmap(i->func_forward(vec(paramarray[i, :])), checkpointfrequency, joinpath(restartdir, checkpointfilename), 1:ncases)
end
r = hcat(collect.(values.(rv))...)
elseif parallel && Distributed.nprocs() > 1
if robustpmap
@info("RobustPmap for parallel execution of forward runs ...")
if s[2] == np
rv = RobustPmap.rpmap(func_forward, collect(paramarray))
else
if s[1] == np
# @show paramarray[:, 1]
# @show collect(values(func_forward(vec(paramarray[:, 1]))))
rv = RobustPmap.rpmap(func_forward, permutedims(collect(paramarray)))
else
# @show paramarray[1, :]
# @show collect(values(func_forward(vec(paramarray[1, :]))))
rv = RobustPmap.rpmap(func_forward, collect(paramarray))
end
r = hcat(collect.(values.(rv))...)
else
@info("Parallel execution of forward runs ...")
if s[2] == np
rv1 = collect(values(func_forward(vec(paramarray[1, :]))))
psa = collect(paramarray) # collect to avoid issues if paramarray is a SharedArray
else
if s[1] == np
rv1 = collect(values(func_forward(vec(paramarray[:, 1]))))
psa = permutedims(collect(paramarray)) # collect to avoid issues if paramarray is a SharedArray

else
rv1 = collect(values(func_forward(vec(paramarray[1, :]))))
psa = collect(paramarray) # collect to avoid issues if paramarray is a SharedArray
end
r = SharedArrays.SharedArray{Float64}(length(rv1), ncases)
r[:, 1] = rv1
Distributed.@everywhere madsdata_c = $madsdata_c
@Distributed.everywhere madsdata_c = $madsdata_c
@sync @Distributed.distributed for i = 2:ncases
func_forward = Mads.makearrayfunction(madsdata_c) # this is needed to avoid issues with the closure
r[:, i] = collect(values(func_forward(vec(psa[i, :]))))
Expand Down
31 changes: 24 additions & 7 deletions src/MadsSensitivityAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ keytext=Dict("N"=>"number of samples [default=`100`]",
"restartdir"=>"directory where files will be stored containing model results for the efast simulation restarts [default=`\"efastcheckpoints\"`]",
"restart"=>"save restart information [default=`false`]")))
"""
function efast(md::AbstractDict; N::Integer=100, M::Integer=6, gamma::Number=4, seed::Integer=-1, checkpointfrequency::Integer=N, save::Bool=true, load::Bool=false, execute::Bool=true, restartdir::AbstractString="efastcheckpoints", restart::Bool=false, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing)
function efast(md::AbstractDict; N::Integer=100, M::Integer=6, gamma::Number=4, seed::Integer=-1, checkpointfrequency::Integer=N, save::Bool=true, load::Bool=false, execute::Bool=true, parallel::Bool=true, robustpmap::Bool=false, restartdir::AbstractString="efastcheckpoints", restart::Bool=false, rng::Union{Nothing,Random.AbstractRNG,DataType}=nothing)
issvr = false
# a: Sensitivity of each Sobol parameter (low: very sensitive, high; not sensitive)
# A and B: Real & Imaginary components of Fourier coefficients, respectively. Used to calculate sensitivty.
Expand Down Expand Up @@ -1357,12 +1357,29 @@ function efast(md::AbstractDict; N::Integer=100, M::Integer=6, gamma::Number=4,
else
=#
# If # of processors is > Nr*nprime+(Nr+1) compute model output in parallel
Mads.madsoutput("Compute model output in parallel ... $(P) > $(Nr*nprime+(Nr+1)) ...\n")
Mads.madsoutput("Computing models in parallel - Parameter k = $k ($(paramkeys[k])) ...\n")
if restart
Y = hcat(RobustPmap.crpmap(i->collect(values(f(merge(paramalldict, OrderedCollections.OrderedDict{String, Float64}(zip(paramkeys, X[i, :])))))), checkpointfrequency, joinpath(restartdir, "efast_$(kL)_$k"), 1:size(X, 1))...)'
else
Y = hcat(RobustPmap.rpmap(i->collect(values(f(merge(paramalldict, OrderedCollections.OrderedDict{String, Float64}(zip(paramkeys, X[i, :])))))), 1:size(X, 1))...)'
Mads.madsoutput("Compute model outputs ... $(P) > $(Nr*nprime+(Nr+1)) ...\n")
Mads.madsoutput("Computing model for Parameter k = $k ($(paramkeys[k])) ...\n")
if robustpmap
if restart
@info("RobustPmap for parallel execution of forward runs with restart ...")
m = RobustPmap.crpmap(i->collect(values(f(merge(paramalldict, OrderedCollections.OrderedDict{String, Float64}(zip(paramkeys, X[i, :])))))), checkpointfrequency, joinpath(restartdir, "efast_$(kL)_$k"), 1:size(X, 1))

else
@info("RobustPmap for parallel execution of forward runs without restart ...")
m = RobustPmap.rpmap(i->collect(values(f(merge(paramalldict, OrderedCollections.OrderedDict{String, Float64}(zip(paramkeys, X[i, :])))))), 1:size(X, 1))
end
Y = permutedims(hcat(m...))
elseif parallel && Distributed.nprocs() > 1
rv1 = collect(values(f(merge(paramalldict, OrderedCollections.OrderedDict{String, Float64}(zip(paramkeys, X[1, :]))))))
psa = collect(X) # collect to avoid issues if paramarray is a SharedArray
r = SharedArrays.SharedArray{Float64}(length(rv1), size(X, 1))
r[:, 1] = rv1
@Distributed.everywhere md = $md
@sync @Distributed.distributed for i = 2:size(X, 1)
func_forward = Mads.makemadscommandfunction(md) # this is needed to avoid issues with the closure
r[:, i] = collect(values(func_forward(merge(paramalldict, OrderedCollections.OrderedDict{String, Float64}(zip(paramkeys, X[i, :]))))))
end
Y = permutedims(collect(r))
end
#end #End if (processors)

Expand Down

0 comments on commit c051730

Please sign in to comment.