Skip to content

Commit

Permalink
Parallelize multinomial sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwilliam77 committed Jan 5, 2021
1 parent 9c977b0 commit bafa58a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
27 changes: 18 additions & 9 deletions src/resample.jl
Original file line number Diff line number Diff line change
@@ -1,38 +1,47 @@
"""
```
resample(weights::AbstractArray; method::Symbol = :systematic)
resample(weights::AbstractArray; n_parts::Int64 = length(weights),
method::Symbol = :systematic, parallel::Bool = false)
```
Reindexing and reweighting samples from a degenerate distribution
### Arguments:
### Input Arguments
- `weights`: wtsim[:,i]
the weights of a degenerate distribution.
### Keyword Arguments
- `n_parts`: length(weights)
the desired length of output vector
- `method`: :systematic, :multinomial, or :polyalgo
the method for resampling
- `parallel`: if true, mulitnomial sampling will be done in parallel.
### Output:
- `indx`: the newly assigned indices of parameter draws.
"""
function resample(weights::Vector{Float64}; n_parts::Int64 = length(weights),
method::Symbol = :systematic)
#n_parts = length(weights)
method::Symbol = :systematic, parallel::Bool = false)

if method == :multinomial
indx = Vector{Int64}(undef, n_parts)

# Stores cumulative weights until given index
cumulative_weights = cumsum(weights ./ sum(weights))
offset = rand(n_parts)

# TODO: parallelize
for i in 1:n_parts
indx[i] = findfirst(x -> offset[i] < x, cumulative_weights)
if parallel
indx = @sync @distributed (vcat) for i in 1:n_parts
findfirst(x -> offset[i] < x, cumulative_weights)
end
else
indx = Vector{Int64}(undef, n_parts)

for i in 1:n_parts
indx[i] = findfirst(x -> offset[i] < x, cumulative_weights)
end
end
return indx

return indx
elseif method == :systematic
# Stores cumulative weights until given index
cumulative_weights = cumsum(weights ./ sum(weights))
Expand Down
5 changes: 3 additions & 2 deletions src/smc_main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ function smc(loglikelihood::Function, parameters::ParameterVector{U}, data::Matr
n_to_resample = Int(round((1-tempered_update_prior_weight) * n_parts))
n_from_prior = n_parts - n_to_resample
new_inds = resample(get_weights(cloud); n_parts = n_to_resample,
method = resampling_method)
method = resampling_method, parallel = parallel)

bridge_cloud = Cloud(n_para, n_to_resample)
update_cloud!(bridge_cloud, cloud.particles[new_inds, :])
Expand Down Expand Up @@ -374,7 +374,8 @@ function smc(loglikelihood::Function, parameters::ParameterVector{U}, data::Matr
if (cloud.ESS[i] < threshold)

# Resample according to particle weights, uniformly reset weights to 1/n_parts
new_inds = resample(normalized_weights/n_parts; method = resampling_method)
new_inds = resample(normalized_weights/n_parts; method = resampling_method,
parallel = parallel)
cloud.particles = [deepcopy(cloud.particles[k,j]) for k in new_inds,
j=1:size(cloud.particles, 2)]
reset_weights!(cloud)
Expand Down

0 comments on commit bafa58a

Please sign in to comment.