Dev (#213)
* Dirichlet(k::Integer, α) = Dirichlet(Fill(α, k))

* export TransformVariables as TV

* drop redundant import

* 0.0 => zero(Float64)

* drop outdated Dists.logpdf

* update StudentT

* drop redundant import

* update Uniform

* bump MeasureBase version

* reworking beta

* small update to StudentT

* basemeasure for discrete Distributions

* using LogExpFunctions => import LogExpFunctions

* quoteof(::Chain)

* prettyprinting and chain-mucking

* Some refactoring for Markov chains

* import MeasureBase: ≪

* version bound for PrettyPrinting

* copy(rng) might change its type (e.g. GLOBAL_RNG)

* tests pass

* cleaning up

* more cleanup

* big update

* get tests passing

* formatting

* oops typo

* move affine to MeasureTheory

* updating

* Val => StaticSymbol

* more fixes

* fix fix fix

* more logdensity => logdensity_def

* more logdensity fixes

* debugging

* formatting

* bugfixes

* working on tests

* updates

* working on tests

* tests passing!

* refactor

* working on tests

* drop static weight for now

* fix sampling from ProductMeasure{<:Base.Generator}

* tests passing!!

* more stuff

* constructor => constructorof

* constructor => constructorof

* updates

* working on tests

* fix Dirichlet

* update Bernoulli

* working on tests

* bugfixes for RealizedSamples

* tests passing!!

* tighten down inference

* as(::PowerMeasure)

* drop type-level stuff

* using InverseFunctions.jl

* update license

* affero

* copyright

* update CI to 1.6

* xform => TV.as

* oops missed a conflict

* fix merge corruption

* typo

* fix license

* Update README.md

* merge

* enumerate instead of zip

* bugfix

* inline rand

* drop `static` from `insupport` results

* update proxies

* Move ConditionalMeasure to MeasureBase

* IfElse.ifelse(p::Bernoulli, t, f)

* IfElseMeasure

* update some base measures

* test broken :(

* fix some redundancies

* instance_type => Core.Typeof

* update testvalue for Bernoulli and Binomial

* un-break broken test (now passing)

* Fall-back `For` method for when inference fails

* drop extra spaces

* more whitespace

* bump MeasureBase dependency version

* add newline

* tidy up

* ifelse tests

* EOF newline

* avoid type piracy

* add Julia 1.7 to CI

* make Julia 1.6 happy

* approx instead of ==

* Require at least Julia 1.6

* Try Sebastian's idea test_measures ::Any[]

* Another Any[]

* Drop Likelihood test

* drop 1.7 CI (seems buggy?)

* bump version

* export likelihood

* Snedecor's F

* Gamma distribution

* more gamma stuff

* Bernoulli()

* inverse Gaussian

* Getting modified GLM.jl tests to pass

* drop pdf and logpdf

* Poisson bugfix

* update Normal(μ, σ²)

* Gamma(μ, ϕ) for GLMs

* updates for GLM support

* start on truncated

* update parameterized measures

* drop FactoredBase

* drop old LazyArrays dependency

* insupport(::Distribution)

* Left out "Dists."

* don't export `ifelse` (#192)

* Kleisli => TransitionKernel

* depend on StatsBase

* tests passing

* bump MeasureBase version

* work on truncated and censored

* improve func_string

* Simplify logdensity_def(::For, x)

* Move truncated and censored updates to separate branches

* newline

* comment out in-progress stuff

* newline

* bump version

* update formatting spec

* more formatting

* tweedie docs

* drop redundant exports

* update exports

* omega => lambda

* drop SequentialEx

* get tests passing

* add kernel tests

* gitignore

* better `Pretty.tile` for Affine and AffineTransforms

* formatting

* kleisli => kernel

* update tile(::For)

* update Compat version

* bump MB version

* update gamma

* Let's come back to InverseGaussian

* CI on 1.7

* update IfElse

* formatting

* update product

* @kwstruct Bernoulli(logitp)

* Base.size(r::RealizedSamples)

* cdf(::Affine, x)

* working on Aqua and JET fixes

* formatting

* So many `rand` methods, do we really need all of these?

* bump version

* bugfix

* bugfix

* loosen constraint
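Several of the bullets above name small API changes; the first one, `Dirichlet(k::Integer, α) = Dirichlet(Fill(α, k))`, is sugar for a symmetric Dirichlet. A minimal sketch of its intended use (assuming MeasureTheory exports `Dirichlet` and that FillArrays is available):

```julia
using MeasureTheory  # assumed: exports Dirichlet
using FillArrays     # provides Fill

# The new method forwards to the Fill-based constructor:
d1 = Dirichlet(3, 2.0)        # k = 3 components, shared concentration α = 2.0
d2 = Dirichlet(Fill(2.0, 3))  # the explicit form it expands to

# Both describe the same symmetric Dirichlet on the 2-simplex.
```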
cscherrer authored Jun 23, 2022
1 parent 8055b52 commit 84c6063
Showing 9 changed files with 132 additions and 30 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MeasureTheory"
uuid = "eadaa1a4-d27c-401d-8699-e962e1bbc33b"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.16.3"
version = "0.16.4"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
@@ -18,6 +18,7 @@ Infinities = "e1ba4f0e-776d-440f-acd9-e1d2e9742647"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
KeywordCalls = "4d827475-d3e4-43d6-abe3-9688362ede9f"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
@@ -53,11 +54,12 @@ IfElse = "0.1"
Infinities = "0.1"
InverseFunctions = "0.1"
KeywordCalls = "0.2"
LazyArrays = "0.22"
LogExpFunctions = "0.3.3"
MLStyle = "0.4"
MacroTools = "0.5"
MappedArrays = "0.4"
MeasureBase = "0.10"
MeasureBase = "0.12"
NamedTupleTools = "0.13, 0.14"
NestedTuples = "0.3"
PositiveFactorizations = "0.2"
17 changes: 14 additions & 3 deletions src/combinators/affine.jl
@@ -1,5 +1,6 @@
export Affine, AffineTransform
using LinearAlgebra
import Base

const AFFINEPARS = [
(:μ, :σ),
@@ -43,8 +44,8 @@ Base.size(f::AffineTransform{(:λ,)}) = size(f.λ)

LinearAlgebra.rank(f::AffineTransform{(:σ,)}) = rank(f.σ)
LinearAlgebra.rank(f::AffineTransform{(:λ,)}) = rank(f.λ)
LinearAlgebra.rank(f::AffineTransform{(:μ,:σ,)}) = rank(f.σ)
LinearAlgebra.rank(f::AffineTransform{(:μ,:λ,)}) = rank(f.λ)
LinearAlgebra.rank(f::AffineTransform{(:μ, :σ)}) = rank(f.σ)
LinearAlgebra.rank(f::AffineTransform{(:μ, :λ)}) = rank(f.λ)

function Base.size(f::AffineTransform{(:μ,)})
(n,) = size(f.μ)
@@ -181,7 +182,7 @@ Affine(nt::NamedTuple, μ::AbstractMeasure) = affine(nt, μ)

Affine(nt::NamedTuple) = affine(nt)

parent(d::Affine) = getfield(d, :parent)
Base.parent(d::Affine) = getfield(d, :parent)

function params(μ::Affine)
nt1 = getfield(getfield(μ, :f), :par)
@@ -262,6 +263,12 @@ end
weightedmeasure(-logjac(d), OrthoLebesgue(params(d)))
end

@inline function basemeasure(
d::MeasureTheory.Affine{N,L,Tuple{A}},
) where {N,L<:MeasureBase.Lebesgue,A<:AbstractArray}
weightedmeasure(-logjac(d), OrthoLebesgue(params(d)))
end

@inline function basemeasure(
d::Affine{N,M,Tuple{A1,A2}},
) where {N,M,A1<:AbstractArray,A2<:AbstractArray}
@@ -328,3 +335,7 @@ end
@inline function insupport(d::Affine, x)
insupport(d.parent, inverse(d.f)(x))
end

@inline function Distributions.cdf(d::Affine, x)
cdf(parent(d), inverse(d.f)(x))
end
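The new `cdf` method above evaluates the parent's CDF at the pulled-back point. A hedged sketch of what that means for a scalar location–scale reparameterization (the parameter values are illustrative, not from the diff):

```julia
using MeasureTheory

# For d = affine((μ, σ), Normal()), inverse(d.f) maps x back to (x - μ) / σ,
# so cdf(d, x) should agree with the standard normal CDF at that point.
d = affine((μ = 1.0, σ = 2.0), Normal())
x = 0.7
cdf(d, x)                        # forwards to cdf(parent(d), inverse(d.f)(x))
cdf(Normal(), (x - 1.0) / 2.0)   # the same quantity, computed directly
```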
7 changes: 5 additions & 2 deletions src/combinators/exponential-families.jl
@@ -1,4 +1,5 @@
export ExponentialFamily
using LazyArrays

@concrete terse struct ExponentialFamily <: AbstractTransitionKernel
support_contains
@@ -16,10 +17,10 @@ function ExponentialFamily(support_contains, base, mdim, pdim, t, a)
return ExponentialFamily(support_contains, base, mdim, pdim, t, I, a)
end

function MeasureBase.powermeasure(fam::ExponentialFamily, dims::NTuple{N,I}) where {N,I}
function MeasureBase.powermeasure(fam::ExponentialFamily, dims::NTuple)
support_contains(x) = all(xj -> fam.support_contains(xj), x)
t = Tuple((y -> f.(y) for f in fam.t))
a(η) = BroadcastArray(fam.a, η)
a(η) = LazyArrays.BroadcastArray(fam.a, η)
p = prod(dims)
ExponentialFamily(
support_contains,
@@ -32,6 +33,8 @@ function MeasureBase.powermeasure(fam::ExponentialFamily, dims::NTuple{N,I}) whe
)
end

powermeasure(fam::ExponentialFamily, ::Tuple{}) = fam

@concrete terse struct ExpFamMeasure <: AbstractMeasure
fam
η # instantiated to a value
18 changes: 17 additions & 1 deletion src/distproxy.jl
@@ -25,7 +25,23 @@ for m in keys(PROXIES)
@eval begin
import $m: $f
export $f
$m.$f(d::AbstractMeasure, args...) = $m.$f(MeasureTheory.proxy(d), args...)
end
end
end

entropy(m::AbstractMeasure, b::Real) = entropy(proxy(m), b)
mean(m::AbstractMeasure) = mean(proxy(m))
std(m::AbstractMeasure) = std(proxy(m))
var(m::AbstractMeasure) = var(proxy(m))
quantile(m::AbstractMeasure, q) = quantile(proxy(m), q)

for f in [
:cdf
:ccdf
:logcdf
:logccdf
]
@eval begin
$f(d::AbstractMeasure, args...) = $f(MeasureTheory.proxy(d), args...)
end
end
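The hunk above replaces part of the metaprogrammed forwarding with explicit methods, but the pattern is the same either way: delegate the statistic to the Distributions.jl proxy. A sketch (parameter values are illustrative):

```julia
using MeasureTheory

m = Normal(μ = 2.0, σ = 3.0)

# Each of these forwards to the Distributions.jl object returned by proxy(m):
mean(m)      # mean(proxy(m)), here 2.0
std(m)       # std(proxy(m)),  here 3.0
cdf(m, 2.0)  # cdf(proxy(m), 2.0), here 0.5 by symmetry
```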
2 changes: 1 addition & 1 deletion src/parameterized/inverse-gaussian.jl
@@ -48,6 +48,6 @@ function logdensity_def(d::InverseGaussian{(:μ, :ϕ)}, x)
end

function basemeasure(d::InverseGaussian{(:μ, :ϕ)})
ℓ = static(-0.5) * (static(log2π) + log(d.ϕ))
ℓ = static(-0.5) * (static(float(log2π)) + log(d.ϕ))
weightedmeasure(ℓ, Lebesgue())
end
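The `float(log2π)` change above sidesteps wrapping an irrational constant in `static`. A hedged illustration of the distinction (assuming `log2π` and `static` come from IrrationalConstants and Static, as is conventional in this ecosystem):

```julia
using IrrationalConstants: log2π  # log2π isa AbstractIrrational
using Static: static

# static expects a plain bits value; converting first gives a StaticFloat64.
s = static(float(log2π))  # float(log2π)::Float64, so this is well-defined

# Passing the Irrational directly — what the old code did — is not
# guaranteed to have a static representation.
```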
34 changes: 21 additions & 13 deletions src/parameterized/mvnormal.jl
@@ -22,14 +22,14 @@ export MvNormal

as(d::MvNormal{(:μ,)}) = as(Array, length(d.μ))

as(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
as(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
as(d::MvNormal{(:μ, :Σ),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
as(d::MvNormal{(:μ, :Λ),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
as(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Σ, 1))
as(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = as(Array, size(d.Λ, 1))
as(d::MvNormal{(:μ, :Σ),<:Tuple{T,C}}) where {T,C<:Cholesky} = as(Array, size(d.Σ, 1))
as(d::MvNormal{(:μ, :Λ),<:Tuple{T,C}}) where {T,C<:Cholesky} = as(Array, size(d.Λ, 1))

function as(d::MvNormal{(:σ,),Tuple{M}}) where {M<:Triangular}
σ = d.σ
if @inbounds all(i -> σ[i] > 0, diagind(σ))
if @inbounds all(i -> σ[i] ≥ 0, diagind(σ))
return as(Array, size(σ, 1))
else
@error "Not implemented yet"
@@ -49,7 +49,7 @@ for N in setdiff(AFFINEPARS, [(:μ,)])
@eval begin
function as(d::MvNormal{$N})
p = proxy(d)
if rank(getfield(p,:f)) == only(supportdim(d))
if rank(getfield(p, :f)) == only(supportdim(d))
return as(Array, supportdim(d))
else
@error "Not yet implemented"
@@ -61,13 +61,13 @@ end
supportdim(d::MvNormal) = supportdim(params(d))

supportdim(nt::NamedTuple{(:Σ,)}) = size(nt.Σ, 1)
supportdim(nt::NamedTuple{(:μ,:Σ)}) = size(nt.Σ, 1)
supportdim(nt::NamedTuple{(:μ, :Σ)}) = size(nt.Σ, 1)
supportdim(nt::NamedTuple{(:Λ,)}) = size(nt.Λ, 1)
supportdim(nt::NamedTuple{(:μ,:Λ)}) = size(nt.Λ, 1)
supportdim(nt::NamedTuple{(:μ, :Λ)}) = size(nt.Λ, 1)

@useproxy MvNormal

for N in [(:Σ,), (:μ,:Σ), (:Λ,), (:μ,:Λ)]
for N in [(:Σ,), (:μ, :Σ), (:Λ,), (:μ, :Λ)]
@eval basemeasure_depth(d::MvNormal{$N}) = static(2)
end

@@ -78,7 +78,15 @@ rand(rng::AbstractRNG, ::Type{T}, d::MvNormal) where {T} = rand(rng, T, proxy(d)
insupport(d::MvNormal, x) = insupport(proxy(d), x)

# Note: (C::Cholesky).L may or may not make a copy, depending on C.uplo, which is not included in the type
@inline proxy(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky} = affine((σ = d.Σ.L,), Normal()^supportdim(d))
@inline proxy(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky} = affine((λ = d.Λ.L,), Normal()^supportdim(d))
@inline proxy(d::MvNormal{(:μ, :Σ),Tuple{C}}) where {C<:Cholesky} = affine((μ = d.μ, σ = d.Σ.L), Normal()^supportdim(d))
@inline proxy(d::MvNormal{(:μ, :Λ),Tuple{C}}) where {C<:Cholesky} = affine((μ = d.μ, λ = d.Λ.L), Normal()^supportdim(d))
@inline function proxy(d::MvNormal{(:Σ,),Tuple{C}}) where {C<:Cholesky}
affine((σ = d.Σ.L,), Normal()^supportdim(d))
end
@inline function proxy(d::MvNormal{(:Λ,),Tuple{C}}) where {C<:Cholesky}
affine((λ = d.Λ.L,), Normal()^supportdim(d))
end
@inline function proxy(d::MvNormal{(:μ, :Σ),Tuple{T,C}}) where {T,C<:Cholesky}
affine((μ = d.μ, σ = d.Σ.L), Normal()^supportdim(d))
end
@inline function proxy(d::MvNormal{(:μ, :Λ),Tuple{T,C}}) where {T,C<:Cholesky}
affine((μ = d.μ, λ = d.Λ.L), Normal()^supportdim(d))
end
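The rewritten `proxy` methods above say how a Cholesky-parameterized `MvNormal` reduces to an affine pushforward of a product of standard normals. A sketch mirroring the test setup later in this commit (the keyword constructor is an assumption):

```julia
using MeasureTheory, LinearAlgebra

x = randn(10, 3)
Σ = cholesky(x' * x)   # positive-definite covariance, Cholesky form

d = MvNormal(Σ = Σ)    # assumed keyword constructor
# proxy(d) is affine((σ = Σ.L,), Normal() ^ 3), so sampling and densities
# all route through the affine representation:
rand(d)
```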
2 changes: 1 addition & 1 deletion src/parameterized/normal.jl
@@ -146,7 +146,7 @@ HalfNormal(σ) = HalfNormal((σ = σ,))
end

@inline function basemeasure(d::Normal{(:σ²,)})
ℓ = static(-0.5) * (static(log2π) + log(d.σ²))
ℓ = static(-0.5) * (static(float(log2π)) + log(d.σ²))
weightedmeasure(ℓ, Lebesgue())
end

57 changes: 54 additions & 3 deletions src/resettable-rng.jl
@@ -78,6 +78,57 @@ function Random.Sampler(
return Random.Sampler(r.rng, s, r)
end

function Base.rand(r::ResettableRNG, sp::Random.Sampler)
rand(r.rng, sp)
end
# UIntBitsTypes = [UInt128, UInt16, UInt32, UInt64, UInt8]
# IntBitsTypes = [Int128, Int16, Int32, Int64, Int8, UInt128, UInt16, UInt32, UInt64, UInt8]
# FloatBitsTypes = [Float16, Float32, Float64]

# for I in IntBitsTypes
# for T in [
# Random.SamplerTrivial{Random.UInt104Raw{I}}
# Random.SamplerTrivial{Random.UInt10Raw{I}}
# ]
# @eval begin
# function Base.rand(r::ResettableRNG, sp::$T)
# rand(r.rng, sp)
# end
# end
# end
# end

# for U in UIntBitsTypes
# for I in IntBitsTypes
# for T in [
# Random.SamplerRangeInt{T,U} where {T<:Union{IntBitsTypes...}}
# Random.SamplerRangeFast{U,I}
# ]
# @eval begin
# function Base.rand(r::ResettableRNG, sp::$T)
# rand(r.rng, sp)
# end
# end
# end
# end
# end

# for T in [
# Random.Sampler
# Random.SamplerBigInt
# Random.SamplerTag{<:Set,<:Random.Sampler}
# # Random.SamplerTrivial{Random.CloseOpen01{T}} where {T<:FloatBitsTypes}
# # Random.SamplerTrivial{Random.UInt23Raw{UInt32}}
# Random.UniformT
# Random.SamplerSimple{T,S,E} where {E,S,T<:Tuple}
# Random.SamplerType{T} where {T<:AbstractChar}
# Random.SamplerTrivial{Tuple{A}} where {A}
# Random.SamplerSimple{Tuple{A,B,C},S,E} where {E,S,A,B,C}
# Random.SamplerSimple{<:AbstractArray,<:Random.Sampler}
# Random.Masked
# Random.SamplerSimple{BitSet,<:Random.Sampler}
# Random.SamplerTrivial{<:Random.UniformBits{T},E} where {E,T}
# ]
# @eval begin
# function Base.rand(r::ResettableRNG, sp::$T)
# rand(r.rng, sp)
# end
# end
# end
19 changes: 15 additions & 4 deletions test/runtests.jl
@@ -12,7 +12,18 @@ using MeasureBase.Interface
using MeasureTheory: kernel
using Aqua
using IfElse
Aqua.test_all(MeasureTheory; ambiguities = false, unbound_args = false)

# Aqua._test_ambiguities(
# Aqua.aspkgids(MeasureTheory);
# exclude = [Random.AbstractRNG],
# # packages::Vector{PkgId};
# # color::Union{Bool, Nothing} = nothing,
# # exclude::AbstractArray = [],
# # # Options to be passed to `Test.detect_ambiguities`:
# # detect_ambiguities_options...,
# )

Aqua.test_all(MeasureBase; ambiguities = false)

function draw2(μ)
x = rand(μ)
@@ -23,8 +34,8 @@ function draw2(μ)
return (x, y)
end

x = randn(10,3)
Σ = cholesky(x'*x)
x = randn(10, 3)
Σ = cholesky(x' * x)
Λ = cholesky(inv(Σ))
σ = MeasureTheory.getL(Σ)
λ = MeasureTheory.getL(Λ)
@@ -611,7 +622,7 @@ end
@testset "IfElseMeasure" begin
p = rand()
x = randn()

@test let
a = logdensityof(IfElse.ifelse(Bernoulli(p), Normal(), Normal()), x)
b = logdensityof(Normal(), x)

2 comments on commit 84c6063

@cscherrer (Collaborator, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/62951

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.16.4 -m "<description of version>" 84c6063a1180c2358d110aa598c2ddcac8653725
git push origin v0.16.4
