Skip to content

Commit

Permalink
Specialize logdensityof for DensityMeasure
Browse files Browse the repository at this point in the history
Ensures proper type propagation (until future refactor of density calculation
engine).
  • Loading branch information
oschulz committed Nov 4, 2024
1 parent 19e8a27 commit e5b8af3
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
7 changes: 7 additions & 0 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@ To compute a log-density relative to a specific base-measure, see
_checksupport(insupport(μ, x), result)
end

@inline function logdensityof_rt(::T, ::U) where {T,U}
Core.Compiler.return_type(logdensityof, Tuple{T,U})
end

_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v) = logdensityof_rt(target, v), _logdensityof_rt_pullback

Check warning on line 41 in src/density-core.jl

View check run for this annotation

Codecov / codecov/patch

src/density-core.jl#L40-L41

Added lines #L40 - L41 were not covered by tests

_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))

import ChainRulesCore
Expand Down
19 changes: 19 additions & 0 deletions src/density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,25 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)

density_def::DensityMeasure, x) = densityof.f, x)

function logdensityof::DensityMeasure, x::Any)
integrand, μ_base = μ.f, μ.base

base_logval = logdensityof(μ_base, x)

T = typeof(base_logval)
U = logdensityof_rt(integrand, x)
R = promote_type(T, U)

# Don't evaluate base measure if integrand is zero or NaN
if isneginf(base_logval)
R(-Inf)
else
integrand_logval = logdensityof(integrand, x)
convert(R, integrand_logval + base_logval)::R
end
end


"""
rebase(μ, ν)
Expand Down
16 changes: 10 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,24 @@ using InverseFunctions: FunctionWithInverse
unwrap(f) = f
unwrap(f::FunctionWithInverse) = f.f


fcomp(f, g) = fchain(g, f)
fcomp(::typeof(identity), g) = g
fcomp(f, ::typeof(identity)) = f
fcomp(::typeof(identity), ::typeof(identity)) = identity

Check warning on line 171 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L169-L171

Added lines #L169 - L171 were not covered by tests

near_neg_inf(::Type{T}) where {T<:Real} = T(-1E38) # Still fits into Float32

Check warning on line 173 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L173

Added line #L173 was not covered by tests

isneginf(x) = isinf(x) && x < zero(x)
isposinf(x) = isinf(x) && x > zero(x)

Check warning on line 176 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L176

Added line #L176 was not covered by tests

near_neg_inf(::Type{T}) where T<:Real = T(-1E38) # Still fits into Float32
_isneginf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isneginf), x) = isneginf(x), _logdensityof_rt_pullback

Check warning on line 179 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L178-L179

Added lines #L178 - L179 were not covered by tests

isneginf(x) = isinf(x) && x < 0
isposinf(x) = isinf(x) && x > 0
_isposinf_pullback(::Any) = (NoTangent(), ZeroTangent())
ChainRulesCore.rrule(::typeof(isposinf), x) = isposinf(x), _isposinf_pullback

Check warning on line 182 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L181-L182

Added lines #L181 - L182 were not covered by tests

isapproxzero(x::T) where T<:Real = x zero(T)
isapproxzero(x::T) where {T<:Real} = x zero(T)
isapproxzero(A::AbstractArray) = all(isapproxzero, A)

Check warning on line 185 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L184-L185

Added lines #L184 - L185 were not covered by tests

isapproxone(x::T) where T<:Real = x one(T)
isapproxone(x::T) where {T<:Real} = x one(T)
isapproxone(A::AbstractArray) = all(isapproxone, A)

Check warning on line 188 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L187-L188

Added lines #L187 - L188 were not covered by tests
16 changes: 16 additions & 0 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,22 @@ end
end
end

@testset "logdensityof" begin
f1 = let A=randn(Float32, 3,3); x -> sum(A*x); end
f2 = x -> sqrt(abs(sum(x)))
f3 = x -> 2 * sum(x)
f4 = x -> sum(sqrt.(abs.(x)))
m = @inferred ∫exp(f1, ∫exp(f2, ∫exp(f3, ∫exp(f4, StdUniform()^3))))

for x in [
Float32[0.7, 0.2, 0.5],
Float32[-0.7, 0.2, 0.5],
]
@test @inferred(logdensityof(m, x)) isa Float32
@test logdensityof(m, x) f1(x) + f2(x) + f3(x) + f4(x) + logdensityof(StdUniform()^3, x)
end
end

@testset "logdensity_rel" begin
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 0.0) == Inf
@test logdensity_rel(Dirac(0.0) + Lebesgue(), Dirac(1.0), 1.0) == -Inf
Expand Down

0 comments on commit e5b8af3

Please sign in to comment.