Skip to content

Commit

Permalink
Add implicitly mapped measures and kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Nov 3, 2024
1 parent b95a935 commit 44cd0b1
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/MeasureBase.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module MeasureBase

using Base: @propagate_inbounds
using Base: OneTo

using Random
import Random: rand!
Expand Down Expand Up @@ -144,6 +145,7 @@ include("combinators/restricted.jl")
include("combinators/smart-constructors.jl")
include("combinators/powerweighted.jl")
include("combinators/conditional.jl")
include("combinators/implicitlymapped.jl")

include("standard/stdmeasure.jl")
include("standard/stduniform.jl")
Expand All @@ -152,6 +154,8 @@ include("standard/stdlogistic.jl")
include("standard/stdnormal.jl")
include("combinators/half.jl")

#include("implicitmaps.jl")

include("rand.jl")
include("fixedrng.jl")

Expand Down
246 changes: 246 additions & 0 deletions src/combinators/implicitlymapped.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@

"""
abstract type ImplicitlyMapped
Supertype for objects that have been mapped in an implicit way.
The explicit map/function can only be determined given some kind of observed
result `obs` using
```julia
f_map = explicit_mapfunc(mapped::ImplicitlyMapped, obs)
```
The original object that has been implicitly mapped
may be retrieved via
```julia
obj = explicit_mapfunc(mapped::ImplicitlyMapped, obs)
```
Note that `obs` is typically *not* the directly result of `f_map(ob)`. Instead,
the relationship between `obj`, `f_map`, and `obs` depends on what `obj` is:
* A measure `mu = obj`: The mapping process is equivalent to
`mapped_mu = pushfwd(f_map, mu, PushfwdRootMeasure())` and `obs` is an
element of the measurable space of `mu`. Implicitly mapped measures support
```julia
DensityInterface.DensityKind(mapped_mu::ImplicitlyMapped)
DensityInterface.logdensityof(mapped_mu::ImplicitlyMapped, obs)
```
and the explicitly mapped measure can be generated via
```julia
explicit_measure(mapped_mu::ImplicitlyMapped, obs)
```
* A transition/Markov kernel `f_kernel = obj`, i.e. a function that maps
points in some space to measures on a (possibly different) space:
The mapping process is equivalent to
`mapped_f_kernel = (p -> pushfwd(f_map, f_kernel(p), PushfwdRootMeasure()))`
and `obs` is an element of the measurable space of the measures generated
by the mapped kernel. Implicitly mapped transition/Markov kernels support
```julia
Likelihood(mapped_f_kernel::ImplicitlyMapped, obs)
```
and the explicitly mapped kernel can be generated via
```julia
explicit_measure(mapped_mu::ImplicitlyMapped, obs)
```
# Implementation
Subtypes of `ImplicitlyMapped` that should support origin measures of
type `SomeRelevantMeasure` and observations of type `SomeRelevantObs`,
resulting in explicit maps/functions of type `SomeMapFunc`, must
implement/specialize
```julia
MeasureBase.implicit_origin(mapped::MyImplicitlyMapped)
MeasureBase.explicit_mapfunc(mapped::MyImplicitlyMapped, obs::SomeRelevantObs)::SomeMapFunc
```
and (except if functions of type `SomeMapFunc` are invertible via
`InverseFunctions.inverse`) must also specialize
```julia
MeasureBase.pushfwd(f::SomeMapFunc, mu::SomeRelevantMeasure, ::PushfwdRootMeasure)
```
Subtypes of `ImplicitlyMapped` may support multiple combinations of
observation and measure types.
"""
abstract type ImplicitlyMapped end
export ImplicitlyMapped


"""
implicit_origin(mapped::ImplicitlyMapped)
Get the original object (a measure or transition/Markov kernel) that was
implicitly mapped.
See [ImplicitlyMapped](@ref) for detailed semantics.
# Implementation
`implicit_origin` must be implemented for subtypes of `ImplicitlyMapped`,
there is no default implementation.
"""
function implicit_origin end
export implicit_origin


"""
explicit_mapfunc(mapped::ImplicitlyMapped, obs)
Get an explicit map/function based on an implicitly mapped object and an
observation.
See [ImplicitlyMapped](@ref) for detailed semantics.
# Implementation
`explicit_mapfunc` must be implemented for subtypes of `ImplicitlyMapped`,
there is no default implementation.
"""
function explicit_mapfunc end
export explicit_mapfunc


"""
explicit_measure(mapped::ImplicitlyMapped, obs)
Get an explicitly mapped measure based on an implicitly mapped measure and an
observation that provides context on which pushforward to use on the onmapped
original measure `implicit_origin(mapped)`.
Used [`explicit_mapfunc`](@ref) to get the function to use in the pushforward.
# Implementation
`explicit_measure` does not need to be specialized for subtypes of
`ImplicitlyMapped`.
"""
function explicit_measure(mapped_measure::ImplicitlyMapped, obs)
f_map = explicit_mapfunc(mapped_measure, obs)
mu = implicit_origin(mapped_measure)
return pushfwd(f_map, mu, PushfwdRootMeasure())
end
export explicit_measure

function DensityInterface.logdensityof(mapped_measure::ImplicitlyMapped, obs)
return logdensityof(explicit_measure(mapped_measure, obs), obs)
end

DensityInterface.DensityKind(mapped::ImplicitlyMapped) = DensityKind(implicit_origin(mapped))

Check warning on line 141 in src/combinators/implicitlymapped.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/implicitlymapped.jl#L141

Added line #L141 was not covered by tests


"""
explicit_kernel(mapped::ImplicitlyMapped, obs)
Get an expliclity mapped transition/Markov kernel, based on an implicitly
mapped kernel and an observation that provides context on which pushforward
to add to the unmapped original kernel `implicit_origin(mapped)`.
Used [`explicit_mapfunc`](@ref) to get the function to use in the pushforward.
# Implementation
`explicit_kernel` does not need to be specialized for subtypes of
`ImplicitlyMapped`.
"""
function explicit_kernel(mapped_kernel::ImplicitlyMapped, obs)
f_map = explicit_mapfunc(mapped_kernel, obs)
f_kernel = implicit_origin(mapped_kernel)
return (p -> pushfwd(f_map, f_kernel(p), PushfwdRootMeasure()))
end
export explicit_kernel


function Likelihood(mapped_kernel::ImplicitlyMapped, obs)
return Likelihood(explicit_kernel(mapped_kernel, obs), obs)
end



"""
struct MeasureBase.TakeAny{T} <: Function
Represents a function that takes n values from a collection.
`f = TakeAny(n)` treats all collections as unordered: `f(xs) may take the
first `n` elements of `xs`, but there is no guarantee. It must, however,
always take take the same elements from collections that are identical.
Constructor: `TakeAny(n::Union{Integer,Static.StaticInteger})`.
"""
struct TakeAny{T<:IntegerLike}
n::T
end

_takeany_range(f::TakeAny, idxs) = first(idxs):first(idxs)+dynamic(f.n)-1
@inline _takeany_range(f::TakeAny, ::OneTo) = OneTo(dynamic(f.n))

@inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::OneTo) where N = SOneTo(N)
@inline _takeany_range(::TakeAny{<:Static.StaticInteger{N}}, ::SOneTo) where N = SOneTo(N)

@inline (f::TakeAny)(xs::Tuple) = xs[begin:begin+f.n-1]

Check warning on line 193 in src/combinators/implicitlymapped.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/implicitlymapped.jl#L193

Added line #L193 was not covered by tests
@inline (f::TakeAny)(xs::AbstractVector) = xs[_takeany_range(f, eachindex(xs))]

function (f::TakeAny)(xs)
n = dynamic(f.n)
ys = collect(Iterators.take(xs, n))
length(ys) != n && throw(ArgumentError("Can't take $n elements from a sequence shorter than $n"))
return typeof(xs)(ys)
end



"""
struct Marginalized{T} <: ImplicitlyMapped
Represents an implicitly marginalized measure or transition kernel.
Constructors:
* `Marginalized(mu)`
* `Marginalized(f_kernel)`
See [ImplicitlyMapped](@ref) for detailed semantics.
Example:
```julia
mu = productmeasure((a = StdUniform(), b = StdNormal(), c = StdExponential()))
obs = (a = 0.7, c = 1.2)
marg_mu_equiv = productmeasure((a = StdUniform(), c = StdExponential()))
logdensityof(Marginalized(mu), obs) ≈ logdensityof(marg_mu_equiv, obs)
```
"""
struct Marginalized{T} <: ImplicitlyMapped
obj::T
end
export Marginalized

implicit_origin(mapped::Marginalized) = mapped.obj

explicit_mapfunc(::Marginalized, obs::NamedTuple{names}) where names = PropSelFunction{names,names}()
pushfwd(f::PropSelFunction, mu::ProductMeasure{<:NamedTuple}, ::PushfwdRootMeasure) = productmeasure(f(marginals(mu)))

explicit_mapfunc(::Marginalized, obs::AbstractVector) = TakeAny(length(obs))
explicit_mapfunc(::Marginalized, obs::StaticArray{Tuple{N}}) where N = TakeAny(static(N))

Check warning on line 239 in src/combinators/implicitlymapped.jl

View check run for this annotation

Codecov / codecov/patch

src/combinators/implicitlymapped.jl#L239

Added line #L239 was not covered by tests

function pushfwd(f::TakeAny, mu::PowerMeasure{<:Any,<:Tuple{<:AbstractUnitRange}}, ::PushfwdRootMeasure)
n = f.n
n_mu = length(mu)
n_mu < n && throw(ArgumentError("Can't marginalize $n_mu dimensional power measure to $n dimensions"))
mu.parent^f.n
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
AffineMaps = "2c83c9a8-abf5-4329-a0d7-deffaf474661"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
Expand Down
60 changes: 60 additions & 0 deletions test/combinators/implicitlymapped.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
using Test

using MeasureBase

using StaticArrays: SVector
using Static: static
using AffineMaps, PropertyFunctions

@testset "implicitlymapped" begin
@testset "TakeAny" begin
V = [3, 2, 4, 2, 7, 5, 6]
mV = [3, 2, 4, 2]
S = Set(V)
mS = Set(mV)
SV = SVector(V...)
mSV = SVector(mV...)

@test @inferred(MeasureBase.TakeAny(4)(V)) == mV
@test @inferred(MeasureBase.TakeAny(static(4))(V)) == mV
tS = @inferred(MeasureBase.TakeAny(4)(S))
@test tS isa Set && length(tS) == 4 && all(x -> x in S, tS)
@test @inferred(MeasureBase.TakeAny(static(4))(S)) == MeasureBase.TakeAny(4)(S)
@test @inferred(MeasureBase.TakeAny(static(4))(SV)) === mSV
@test @inferred(MeasureBase.TakeAny(4)(SV)) == mV
@test @inferred(MeasureBase.TakeAny(4)(V)) == mV
end

function test_implicitly_mapped(label, f_kernel, ref_mapfunc, ref_mappedkernel, par, orig_obs, obs)
@testset "$label" begin
im_measure = @inferred Marginalized(f_kernel(par))
im_kernel = @inferred Marginalized(f_kernel)
mapfunc = @inferred explicit_mapfunc(im_measure, obs)
mapped_measure = @inferred explicit_measure(im_measure, obs)
mapped_likelihood = @inferred Likelihood(im_kernel, obs)

@test mapfunc == ref_mapfunc
@test @inferred(mapfunc(orig_obs)) == obs
@test mapped_measure == ref_mappedkernel(par)

@test @inferred(logdensityof(im_measure, obs)) logdensityof(mapped_measure, obs)
@test @inferred(logdensityof(mapped_likelihood, par)) logdensityof(Likelihood(ref_mappedkernel, obs), par)
end
end

f_kernel = par -> productmeasure(map(m -> pushfwd(Mul(par), m), (a = StdUniform(), b = StdNormal(), c = StdExponential())))
ref_mapfunc = @pf (;$a, $c)
ref_mappedkernel = par -> productmeasure(map(m -> pushfwd(Mul(par), m), (a = StdUniform(), c = StdExponential())))
par = 4.2
orig_obs = (a = 0.7, b = 2.1, c = 1.2)
obs = (a = 0.7, c = 1.2)
test_implicitly_mapped("marginalized nt", f_kernel, ref_mapfunc, ref_mappedkernel, par, orig_obs, obs)

f_kernel = par -> pushfwd(Mul(par), StdNormal())^7
ref_mapfunc = MeasureBase.TakeAny(3)
ref_mappedkernel = par -> pushfwd(Mul(par), StdNormal())^3
par = 4.2
orig_obs = [9.4, -7.3, 1.0, -2.9, 1.9, 4.7, 0.5]
obs = [9.4, -7.3, 1.0]
test_implicitly_mapped("marginalized nt", f_kernel, ref_mapfunc, ref_mappedkernel, par, orig_obs, obs)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ include("smf.jl")

include("combinators/weighted.jl")
include("combinators/transformedmeasure.jl")
include("combinators/implicitlymapped.jl")

include("test_docs.jl")

0 comments on commit 44cd0b1

Please sign in to comment.