Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fit!(o, y, n) to fit multiple observations of the same value #40

Merged
merged 4 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 52 additions & 16 deletions src/OnlineStatsBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,39 @@ the type of a single observation for the provided `stat`, `fit!` will attempt to
through and `fit!` each item in `data`. Therefore, `fit!(Mean(), 1:10)` translates
roughly to:

o = Mean()
```julia
o = Mean()

for x in 1:10
fit!(o, x)
for x in 1:10
fit!(o, x)
end
```
"""
fit!(o::OnlineStat{T}, y::T) where {T} = (_fit!(o, y); return o)

function fit!(o::OnlineStat{I}, y::T) where {I, T}
T == eltype(y) && error("The input for $(name(o,false,false)) is $I. Found $T.")
for yi in y
fit!(o, yi)
end
o
end

"""
fit!(o::OnlineStat{T}, yi::T) where {T} = (_fit!(o, yi); return o)
fit!(stat::OnlineStat, y, n)

Update the "sufficient statistics" of a `stat` with multiple observations of a single value.
Unless a specialized formula is used, `fit!(Mean(), 10, 5)` is equivalent to:

```julia
o = Mean()

for _ in 1:5
fit!(o, 10)
end
```
"""
fit!(o::OnlineStat{T}, y::T, n::Integer) where {T} = (_fit!(o, y, n); return o)

"""
fit!(stat1::OnlineStat, stat2::OnlineStat)
Expand All @@ -121,23 +147,33 @@ Useful for reductions of OnlineStats using `fit!`.

# Example

julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3]
3-element Vector{Mean{Float64, EqualWeight}}:
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
```julia-repl
julia> v = [reduce(fit!, [1, 2, 3], init=Mean()) for _ in 1:3]
3-element Vector{Mean{Float64, EqualWeight}}:
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0
Mean: n=3 | value=2.0

julia> reduce(fit!, v, init=Mean())
Mean: n=9 | value=2.0
julia> reduce(fit!, v, init=Mean())
Mean: n=9 | value=2.0
```
"""
fit!(o::OnlineStat, o2::OnlineStat) = merge!(o, o2)

function fit!(o::OnlineStat{I}, y::T) where {I, T}
T == eltype(y) && error("The input for $(name(o,false,false)) is $I. Found $T.")
for yi in y
fit!(o, yi)
# general fallback for _fit!(o, y) that each stat must implement
function _fit!(o::OnlineStat{T}, y) where {T}
error("_fit!(o, y) is not implemented for $(name(o,false,false)). If you are writing " *
"a new statistic, then this must be implemented. If you are a user, then please " *
"submit a bug report.")
end

# general fallback for _fit!(o, y, n) that is optional to implement
function _fit!(o::OnlineStat{T}, y, n) where {T}
for _ in 1:n
_fit!(o, y)
end
o

return o
end

#-----------------------------------------------------------------------# utils
Expand Down
23 changes: 23 additions & 0 deletions src/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ mutable struct Counter{T} <: OnlineStat{T}
end
Counter(T = Number) = Counter{T}()
_fit!(o::Counter{T}, y) where {T} = (o.n += 1)
_fit!(o::Counter{T}, y, n) where {T} = (o.n += n)
_merge!(a::Counter, b::Counter) = (a.n += b.n)

#-----------------------------------------------------------------------# CountMap
Expand Down Expand Up @@ -122,6 +123,10 @@ function _fit!(o::CountMap{T}, xy::Pair{<:T, <:Integer}) where {T}
o.n += y
o.value[x] = get!(o.value, x, 0) + y
end
function _fit!(o::CountMap{T}, x, n) where {T}
o.n += n
o.value[x] = get!(o.value, x, 0) + n
end

_merge!(o::CountMap, o2::CountMap) = (merge!(+, o.value, o2.value); o.n += o2.n)
function probs(o::CountMap, kys = keys(o.value))
Expand Down Expand Up @@ -251,6 +256,18 @@ function _fit!(o::Extrema, y)
y == o.min && (o.nmin += 1)
y == o.max && (o.nmax += 1)
end
function _fit!(o::Extrema, y, n)
(o.n += n) == n && (o.min = o.max = y)
if y < o.min
o.min = y
o.nmin = 0
elseif y > o.max
o.max = y
o.nmax = 0
end
y == o.min && (o.nmin += n)
y == o.max && (o.nmax += n)
end
function _merge!(a::Extrema, b::Extrema)
if a.min == b.min
a.nmin += b.nmin
Expand Down Expand Up @@ -451,6 +468,10 @@ Mean(T::Type{<:Number} = Float64; weight = EqualWeight()) = Mean(zero(T), weight
function _fit!(o::Mean{T}, x) where {T}
o.μ = smooth(o.μ, x, o.weight(o.n += 1))
end
function _fit!(o::Mean{T}, y, n) where {T}
o.n += n
o.μ = smooth(o.μ, y, n / o.n)
end
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't be correct for Mean{T,W} when W is something other than EqualWeight

Copy link
Contributor Author

@adknudson adknudson Apr 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this imply there is also a problem with _merge!(o::Mean, o2::Mean)? My thinking is along the lines that fitting y k times is equivalent to merging two means where the second one has k values and a mean of y. What is the appropriate way to deal with weights? Like this?

function _fit!(o::Mean{T}, y, n) where {T}
    o.n += n
    o.μ = smooth(o.μ, y, o.weight(o.n / n))
end

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, in testing that doesn't work. I can define a special version for Mean with equal weight, and then use the generic fallback otherwise

function _merge!(o::Mean, o2::Mean)
o.n += o2.n
o.μ = smooth(o.μ, o2.μ, o2.n / o.n)
Expand Down Expand Up @@ -524,6 +545,8 @@ Sum(T::Type = Float64) = Sum(T(0), 0)
Base.sum(o::Sum) = o.sum
_fit!(o::Sum{T}, x::Real) where {T<:AbstractFloat} = (o.sum += convert(T, x); o.n += 1)
_fit!(o::Sum{T}, x::Real) where {T<:Integer} = (o.sum += round(T, x); o.n += 1)
_fit!(o::Sum{T}, x::Real, n) where {T<:AbstractFloat} = (o.sum += convert(T, x * n); o.n += n)
_fit!(o::Sum{T}, x::Real, n) where {T<:Integer} = (o.sum += round(T, x * n); o.n += n)
_merge!(o::T, o2::T) where {T <: Sum} = (o.sum += o2.sum; o.n += o2.n; o)

#-----------------------------------------------------------------------# Variance
Expand Down
55 changes: 55 additions & 0 deletions test/test_stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ println(" > CircBuff")
fit!(b, 3:11)
@test b[end] == 7
@test b[1] == 11

# Multiple obs method
c = CircBuff(Int, 5)
fit!(c, 5, 5)
fit!(c, 10)
@test c[1] == 5
@test c[end] == 10
end

#-----------------------------------------------------------------------# Counter
Expand All @@ -50,6 +57,10 @@ println(" > Counter")
o2 = fit!(Counter(Int), 1)
@test value(merge!(o, o2)) == 11
==(mergevals(Counter(), y, y2)...)

# Multiple obs method
o3 = fit!(Counter(Int), 1, 10)
@test (value(o3)) == 10
end

#-----------------------------------------------------------------------# CountMap
Expand Down Expand Up @@ -82,6 +93,13 @@ println(" > CountMap")
# Pair method
@test ==(mergevals(CountMap(Bool), Pair.(x,z), Pair.(x2,z2); nobs_equals_length=false)...)
@test ==(mergevals(CountMap(Int), Pair.(z,z), Pair.(z2,z2); nobs_equals_length=false)...)

# Multiple obs method
c = fit!(CountMap(Bool), true, 10)
fit!(c, false, 5)
@test nobs(c) == 15
@test c[true] == 10
@test c[false] == 5
end
#-----------------------------------------------------------------------# CountMissing
println(" > CountMissing")
Expand Down Expand Up @@ -143,6 +161,15 @@ println(" > Extrema")
o = fit!(Extrema(), x)
@test o.nmin == length(x) - sum(x)
@test o.nmax == sum(x)

# Multiple obs method
o = fit!(Extrema(), y)
fit!(o, 20, 5)
fit!(o, -20, 7)
@test o.nmax == 5
@test o.nmin == 7
@test maximum(o) == 20
@test minimum(o) == -20
end
#-----------------------------------------------------------------------------# ExtremeValues
println(" > ExtremeValues")
Expand Down Expand Up @@ -227,6 +254,12 @@ println(" > Mean")
@test value(o) ≈ mean(y)
@test mean(o) ≈ mean(y)
@test ≈(mergevals(Mean(), y, y2)...)

# Multiple obs method
o = fit!(Mean(), y)
fit!(o, 1.0, 4)
v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0])
@test mean(o) ≈ mean(v)
end
#-----------------------------------------------------------------------# Moments
println(" > Moments")
Expand All @@ -241,6 +274,17 @@ println(" > Moments")
for (v1,v2) in zip(mergevals(Moments(), y, y2)...)
@test v1 ≈ v2
end

# Multiple obs method
o = fit!(Moments(), y)
fit!(o, 1.0, 4)
v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0])
@test value(o) ≈ [mean(v), mean(v .^ 2), mean(v .^ 3), mean(v .^ 4)]
@test mean(o) ≈ mean(v)
@test var(o) ≈ var(v)
@test std(o) ≈ std(v)
@test skewness(o) ≈ skewness(v)
@test kurtosis(o) ≈ kurtosis(v)
end

#-----------------------------------------------------------------------# Series
Expand Down Expand Up @@ -281,6 +325,9 @@ println(" > Sum")
@test ==(mergevals(Sum(Int), x, x2)...)
@test ≈(mergevals(Sum(), y, y2)...)
@test ==(mergevals(Sum(Int), z, z2)...)

# Multiple obs method
@test sum(fit!(Sum(Int), 10, 5)) == 50
end

#-----------------------------------------------------------------------------# TryCatch
Expand Down Expand Up @@ -312,6 +359,14 @@ println(" > Variance")
@test std(fit!(Variance(), [1, 2])) == sqrt(.5)
# https://github.com/joshday/OnlineStats.jl/issues/217
@test value(fit!(Variance(Float32), randn(Float32, 10))) isa Float32

# Multiple obs method
o = fit!(Variance(), y)
fit!(o, 1.0, 4)
v = vcat(copy(y), [1.0, 1.0, 1.0, 1.0])
@test mean(o) ≈ mean(v)
@test var(o) ≈ var(v)
@test std(o) ≈ std(v)
end

end # end "Test Stats"
Loading