Skip to content

Commit

Permalink
Merge gpu code
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Oct 27, 2024
1 parent b78efce commit c791095
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 50 deletions.
2 changes: 2 additions & 0 deletions src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ using NNlib

abstract type AbstractFilter <: AbstractSampler end

abstract type AbstractParticleFilter{N} <: AbstractFilter end

"""
predict([rng,] model, alg, iter, state; kwargs...)
Expand Down
15 changes: 7 additions & 8 deletions src/algorithms/apf.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
export AuxiliaryParticleFilter, APF

mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter
N::Integer
mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: AbstractParticleFilter{N}
resampler::RS
aux::Vector # Auxiliary weights
end
Expand All @@ -10,20 +9,20 @@ function AuxiliaryParticleFilter(
N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N))
return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N))
end

const APF = AuxiliaryParticleFilter

function initialise(
rng::AbstractRNG,
model::StateSpaceModel{T},
filter::AuxiliaryParticleFilter;
filter::AuxiliaryParticleFilter{N},
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
) where {T}
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
initial_weights = fill(-log(T(filter.N)), filter.N)
) where {N,T}
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
initial_weights = zeros(T, N)

return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter)
end
Expand Down Expand Up @@ -86,7 +85,7 @@ function update(
states.filtered.log_weights = states.proposed.log_weights + log_increments
states.filtered.particles = states.proposed.particles

return (states, logsumexp(log_increments) - log(T(filter.N)))
return states, logmarginal(states, filter)
end

function step(
Expand Down
22 changes: 7 additions & 15 deletions src/algorithms/bootstrap.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,28 @@
export BootstrapFilter, BF

struct BootstrapFilter{RS<:AbstractResampler} <: AbstractFilter
N::Integer
struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N}
resampler::RS
end

function BootstrapFilter(
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return BootstrapFilter(N, conditional_resampler)
return BootstrapFilter{N, typeof(conditional_resampler)}(conditional_resampler)
end

"""Shorthand for `BootstrapFilter`"""
const BF = BootstrapFilter

function BootstrapFilter(
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
)
conditional_resampler = ESSResampler(threshold, resampler)
return BootstrapFilter(N, conditional_resampler)
end

function initialise(
rng::AbstractRNG,
model::StateSpaceModel{T},
filter::BootstrapFilter;
filter::BootstrapFilter{N};
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
) where {T}
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
initial_weights = zeros(T, filter.N)
) where {N,T}
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
initial_weights = zeros(T, N)

return update_ref!(
ParticleContainer(initial_states, initial_weights), ref_state, filter
Expand Down Expand Up @@ -71,7 +63,7 @@ function update(
states.filtered.log_weights = states.proposed.log_weights + log_increments
states.filtered.particles = states.proposed.particles

return states, logmarginal(states)
return states, logmarginal(states, filter)
end

function reset_weights!(
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/rbpf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ function update(

states.filtered.log_weights = states.proposed.log_weights + log_increments

return states, logmarginal(states)
return states, logmarginal(states, algo)
end

#################################
Expand Down
6 changes: 5 additions & 1 deletion src/containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,15 @@ Base.keys(state::ParticleState) = LinearIndices(state.particles)
Base.@propagate_inbounds Base.getindex(state::ParticleState, i) = state.particles[i]
# Base.@propagate_inbounds Base.getindex(state::ParticleState, i::Vector{Int}) = state.particles[i]

function reset_weights!(state::ParticleState{T,WT}) where {T,WT<:Real}
function reset_weights!(state::ParticleState{T,WT}, idx, ::AbstractFilter) where {T,WT<:Real}
fill!(state.log_weights, zero(WT))
return state.log_weights
end

function logmarginal(states::ParticleContainer, ::AbstractFilter)
return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
end

function update_ref!(
pc::ParticleContainer{T},
ref_state::Union{Nothing,AbstractVector{T}},
Expand Down
33 changes: 8 additions & 25 deletions src/resamplers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ function resample(
rng::AbstractRNG,
resampler::AbstractResampler,
states::ParticleState{PT,WT},
filter::AbstractFilter,
filter::AbstractFilter;
weights::AbstractVector{WT}=StatsBase.weights(states)
) where {PT,WT}
weights = StatsBase.weights(states)
idxs = sample_ancestors(rng, resampler, weights)

new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states)))

new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states)))
reset_weights!(new_state, idxs, filter)
return new_state, idxs
end

Expand All @@ -26,8 +25,9 @@ function resample(
rng::AbstractRNG,
resampler::AbstractResampler,
states::RaoBlackwellisedParticleState{T,M,ZT},
::AbstractFilter;
weights=StatsBase.weights(states)
) where {T,M,ZT}
weights = StatsBase.weights(states)
idxs = sample_ancestors(rng, resampler, weights)

new_state = RaoBlackwellisedParticleState(
Expand All @@ -39,23 +39,6 @@ function resample(
return new_state, idxs
end

# TODO: combine this with above definition
function resample(
rng::AbstractRNG,
resampler::AbstractResampler,
states::RaoBlackwellisedParticleState{T,M,ZT},
) where {T,M,ZT}
weights = StatsBase.weights(states)
idxs = sample_ancestors(rng, resampler, weights)

new_state = RaoBlackwellisedParticleState(
deepcopy(states.x_particles[:, idxs]),
deepcopy(states.z_particles[idxs]),
CUDA.zeros(T, length(states)),
)
return reset_weights!(state, idxs, filter)
end

## CONDITIONAL RESAMPLING ##################################################################

abstract type AbstractConditionalResampler <: AbstractResampler end
Expand All @@ -69,7 +52,7 @@ struct ESSResampler <: AbstractConditionalResampler
end

function resample(
rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}
rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}, filter::AbstractFilter
) where {PT,WT}
n = length(state)
# TODO: computing weights twice. Should create a wrapper to avoid this
Expand All @@ -78,7 +61,7 @@ function resample(
@debug "ESS: $ess"

if cond_resampler.threshold * n ess
return resample(rng, cond_resampler.resampler, state)
return resample(rng, cond_resampler.resampler, state, filter; weights=weights)
else
return deepcopy(state), collect(1:n)
end
Expand Down

0 comments on commit c791095

Please sign in to comment.