Skip to content

Commit

Permalink
Specialize logdensityof for DensityMeasure
Browse files Browse the repository at this point in the history
Ensures proper type propagation (until future refactor of density calculation
engine).
  • Loading branch information
oschulz committed Nov 4, 2024
1 parent 2131ac3 commit 500ed55
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 7 deletions.
19 changes: 19 additions & 0 deletions ext/MeasureBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@ using ChainRulesCore: NoTangent, ZeroTangent
import ChainRulesCore


# = utils ====================================================================

using MeasureBase: isneginf, isposinf

_isneginf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isneginf), x) = isneginf(x), _logdensityof_rt_pullback

Check warning on line 15 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L14-L15

Added lines #L14 - L15 were not covered by tests

_isposinf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isposinf), x) = isposinf(x), _isposinf_pullback

Check warning on line 18 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L17-L18

Added lines #L17 - L18 were not covered by tests


# = insupport & friends ======================================================

using MeasureBase:
Expand Down Expand Up @@ -44,4 +55,12 @@ _check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback


# = return type inference ====================================================

using MeasureBase: logdensityof_rt

_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v) = logdensityof_rt(target, v), _logdensityof_rt_pullback

Check warning on line 63 in ext/MeasureBaseChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasureBaseChainRulesCoreExt.jl#L62-L63

Added lines #L62 - L63 were not covered by tests


end # module MeasureBaseChainRulesCoreExt
4 changes: 4 additions & 0 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ To compute a log-density relative to a specific base-measure, see
_checksupport(insupport(μ, x), result)
end

@inline function logdensityof_rt(::T, ::U) where {T,U}
Core.Compiler.return_type(logdensityof, Tuple{T,U})
end

_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))


Expand Down
19 changes: 19 additions & 0 deletions src/density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,25 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)

density_def::DensityMeasure, x) = densityof.f, x)

function logdensityof::DensityMeasure, x::Any)
integrand, μ_base = μ.f, μ.base

base_logval = logdensityof(μ_base, x)

T = typeof(base_logval)
U = logdensityof_rt(integrand, x)
R = promote_type(T, U)

# Don't evaluate base measure if integrand is zero or NaN
if isneginf(base_logval)
R(-Inf)
else
integrand_logval = logdensityof(integrand, x)
convert(R, integrand_logval + base_logval)::R
end
end


"""
rebase(μ, ν)
Expand Down
12 changes: 5 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,18 @@ using InverseFunctions: FunctionWithInverse
unwrap(f) = f
unwrap(f::FunctionWithInverse) = f.f


fcomp(f, g) = fchain(g, f)
fcomp(::typeof(identity), g) = g
fcomp(f, ::typeof(identity)) = f
fcomp(::typeof(identity), ::typeof(identity)) = identity

Check warning on line 171 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L169-L171

Added lines #L169 - L171 were not covered by tests

near_neg_inf(::Type{T}) where {T<:Real} = T(-1E38) # Still fits into Float32

Check warning on line 173 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L173

Added line #L173 was not covered by tests

near_neg_inf(::Type{T}) where T<:Real = T(-1E38) # Still fits into Float32

isneginf(x) = isinf(x) && x < 0
isposinf(x) = isinf(x) && x > 0
isneginf(x) = isinf(x) && x < zero(x)
isposinf(x) = isinf(x) && x > zero(x)

Check warning on line 176 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L176

Added line #L176 was not covered by tests

isapproxzero(x::T) where T<:Real = x zero(T)
isapproxzero(x::T) where {T<:Real} = x zero(T)
isapproxzero(A::AbstractArray) = all(isapproxzero, A)

Check warning on line 179 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L178-L179

Added lines #L178 - L179 were not covered by tests

isapproxone(x::T) where T<:Real = x one(T)
isapproxone(x::T) where {T<:Real} = x one(T)
isapproxone(A::AbstractArray) = all(isapproxone, A)

Check warning on line 182 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L181-L182

Added lines #L181 - L182 were not covered by tests
16 changes: 16 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,22 @@ end
end
end

@testset "logdensityof" begin
f1 = let A=randn(Float32, 3,3); x -> sum(A*x); end
f2 = x -> sqrt(abs(sum(x)))
f3 = x -> 2 * sum(x)
f4 = x -> sum(sqrt.(abs.(x)))
m = @inferred ∫exp(f1, ∫exp(f2, ∫exp(f3, ∫exp(f4, StdUniform()^3))))

for x in [
Float32[0.7, 0.2, 0.5],
Float32[-0.7, 0.2, 0.5],
]
@test @inferred(logdensityof(m, x)) isa Float32
@test logdensityof(m, x) f1(x) + f2(x) + f3(x) + f4(x) + logdensityof(StdUniform()^3, x)
end
end

@testset "logdensity_rel" begin
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 0.0) == Inf
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 1.0) == -Inf
Expand Down

0 comments on commit 500ed55

Please sign in to comment.