Skip to content

Commit

Permalink
averaging
Browse files Browse the repository at this point in the history
  • Loading branch information
grzanka committed Mar 19, 2024
1 parent 5024e18 commit a89b352
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 15 deletions.
50 changes: 39 additions & 11 deletions pymchelper/averaging.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,53 @@
from dataclasses import dataclass
from typing import Union, Optional
from numpy.typing import ArrayLike

# References: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
# and https://justinwillmert.com/posts/2022/notes-on-calculating-online-statistics/
# TODO add tests

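# Notation: weight_fraction = weight_factor / total_weight,
# i.e. omega_n = w_n / W_n, where w_n is the weight of the n-th value
# and W_n = W_{n-1} + w_n is the running sum of all weights seen so far.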


@dataclass
class WeightedStats:
    """
    Class for calculating weighted mean and variance of a sequence of numbers.

    Values are folded in one at a time with `update`; each value may be a scalar
    or a NumPy array (arrays are processed element-wise, all with the same weight).

    According to https://justinwillmert.com/posts/2022/notes-on-calculating-online-statistics/
    Heavily based on Welford's algorithm [1]
    [1] Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
    Technometrics. 4 (3): 419–420.
    See also:
    [2] Schubert, Erich, and Michael Gertz.
    "Numerically stable parallel computation of (co-) variance."
    Proceedings of the 30th international conference on scientific and statistical database management. 2018.
    """
    # running weighted mean; NaN until the first call to `update`
    mean: Union[float, ArrayLike] = float('nan')
    # running sum of weight * (value - new_mean) * (value - old_mean); NaN until the first update
    accumulator_S: Union[float, ArrayLike] = float('nan')
    # not read or written by any method of this class; kept for interface compatibility
    temp: Union[float, ArrayLike] = float('nan')
    # W_n: sum of all weights passed to `update`
    total_weight: float = 0
    # sum of the squares of all weights passed to `update`
    total_weight_squared: float = 0

    def update(self, value: Union[float, ArrayLike], weight: float = 1.0):
        """
        Fold one weighted observation into the running statistics.

        Args:
            value: new observation, a scalar or an array supporting arithmetic operators.
            weight: non-negative weight of the observation.

        Raises:
            ValueError: if `weight` is negative.
            ZeroDivisionError: with plain Python numbers, if the total weight is still
                zero after adding `weight` (e.g. the very first update has zero weight).
        """
        if weight < 0:
            raise ValueError("Weight must be non-negative")

        # first pass initialization: zeros shaped like `value`; multiplying by a float
        # zero keeps integer inputs from locking the accumulators to an integer dtype
        if self.total_weight == 0:
            self.mean = value * 0.0
            self.accumulator_S = value * 0.0

        # W_n = W_{n-1} + w_n
        self.total_weight += weight
        self.total_weight_squared += weight**2

        # x_n - mu_{n-1}, captured before the mean changes
        delta_old = value - self.mean

        # mu_n = mu_{n-1} + (w_n / W_n) * (x_n - mu_{n-1})
        # Rebinding (rather than `+=`) is deliberate: an in-place add would mutate the
        # array that the "old mean" still refers to and corrupt the variance accumulator.
        self.mean = self.mean + (weight / self.total_weight) * delta_old

        # S_n = S_{n-1} + w_n * (x_n - mu_n) * (x_n - mu_{n-1})
        self.accumulator_S = self.accumulator_S + weight * (value - self.mean) * delta_old

    def variance_population(self):
        """Return the accumulated sum S divided by the total weight."""
        return self.accumulator_S / self.total_weight

    def variance_sample(self):
        """
        Return the accumulated sum S divided by (total weight - 1).

        NOTE(review): the `W - 1` denominator looks like the correction for integer
        frequency weights; confirm this is the intended weight interpretation.
        """
        return self.accumulator_S / (self.total_weight - 1)
106 changes: 102 additions & 4 deletions tests/test_weighted_stats.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import pytest
from pymchelper.averaging import WeightedStats # Make sure to import your class correctly
import numpy as np
from pymchelper.averaging import WeightedStats


def test_initial_state():
    """A freshly constructed accumulator has a NaN mean and zero total weight."""
    ws = WeightedStats()
    # the mean defaults to NaN before any update, and NaN never compares equal,
    # so it has to be checked with isnan rather than ==
    assert np.isnan(ws.mean)
    assert ws.total_weight == 0


Expand All @@ -31,11 +33,107 @@ def test_multiple_updates():

def test_zero_weight():
    """A zero weight on a fresh accumulator must raise instead of producing a mean."""
    ws = WeightedStats()
    # update() divides by the running total weight, which is still zero here, so plain
    # Python numbers raise ZeroDivisionError; ValueError is accepted as well in case
    # the weight validation is made explicit
    with pytest.raises((ValueError, ZeroDivisionError)):
        ws.update(value=10, weight=0)


def test_negative_weight():
    """A negative weight is rejected with a ValueError."""
    ws = WeightedStats()
    # pin the exact exception type and message instead of accepting any Exception
    with pytest.raises(ValueError, match="non-negative"):
        ws.update(value=10, weight=-1)


def test_update_with_1d_array():
    """Feeding the elements of a 1-D array one by one reproduces the weighted mean."""
    values = np.array([10, 20, 30])
    weights = np.array([2, 3, 5])

    # reference result computed directly from the full arrays
    total_weight = weights.sum()
    expected_mean = np.dot(values, weights) / total_weight

    ws = WeightedStats()
    for sample, sample_weight in zip(values, weights):
        ws.update(sample, sample_weight)

    assert ws.total_weight == total_weight
    assert pytest.approx(ws.mean, 0.001) == expected_mean


def test_update_with_flattened_array():
    """Feeding the elements of flattened 2-D arrays reproduces the weighted mean."""
    values = np.array([[10, 20], [30, 40]]).flatten()
    weights = np.array([[2, 3], [4, 1]]).flatten()

    # reference result computed directly from the flattened arrays
    total_weight = weights.sum()
    expected_mean = np.dot(values, weights) / total_weight

    ws = WeightedStats()
    for sample, sample_weight in zip(values, weights):
        ws.update(sample, sample_weight)

    assert ws.total_weight == total_weight
    assert pytest.approx(ws.mean, 0.001) == expected_mean


def compute_expected_variance(values, weights, total_weight, is_sample=False):
    """
    Utility function to compute the expected variance.

    Args:
        values: observations (array or any sequence convertible to one).
        weights: weight of each observation, same length as `values`.
        total_weight: denominator base, normally the sum of `weights`.
        is_sample: if True divide by (total_weight - 1), otherwise by total_weight.

    Returns:
        The weighted sum of squared deviations from the weighted mean,
        divided by the selected denominator.
    """
    # coerce so that plain lists work with the element-wise arithmetic below
    values = np.asarray(values)
    weights = np.asarray(weights)

    weighted_mean = np.average(values, weights=weights)
    squared_deviation_sum = np.sum(weights * (values - weighted_mean)**2)

    denominator = (total_weight - 1) if is_sample else total_weight
    return squared_deviation_sum / denominator


def test_variance_population_single_update():
    """After a single update the population variance is exactly zero."""
    stats = WeightedStats()
    stats.update(value=10, weight=2)

    # one observation has no spread around its own mean
    spread = stats.variance_population()
    assert spread == 0


def test_variance_population_multiple_updates():
    """The online population variance matches the directly computed reference."""
    values = np.array([10, 20, 30])
    weights = np.array([2, 3, 5])
    total_weight = weights.sum()
    expected_variance = compute_expected_variance(values, weights, total_weight)

    ws = WeightedStats()
    for sample, sample_weight in zip(values, weights):
        ws.update(sample, sample_weight)

    assert pytest.approx(ws.variance_population(), 0.001) == expected_variance


def test_variance_sample_multiple_updates():
    """The online sample variance matches the directly computed reference."""
    ws = WeightedStats()
    values = np.array([10, 20, 30])
    weights = np.array([2, 3, 5])
    total_weight = weights.sum()

    for value, weight in zip(values, weights):
        ws.update(value, weight)

    # total_weight is 10 for this fixed data, so the sample variance is well defined;
    # the former else-branch (total_weight <= 1) could never run and was removed
    expected_variance = compute_expected_variance(values, weights, total_weight, is_sample=True)
    assert pytest.approx(ws.variance_sample(), 0.001) == expected_variance


def test_variance_with_1d_array():
    """Population and sample variance both match the reference for 1-D array data."""
    ws = WeightedStats()
    values = np.array([10, 20, 30])
    weights = np.array([2, 3, 5])
    total_weight = weights.sum()

    for value, weight in zip(values, weights):
        ws.update(value, weight)

    expected_variance_population = compute_expected_variance(values, weights, total_weight)
    assert pytest.approx(ws.variance_population(), 0.001) == expected_variance_population

    # total_weight is 10 for this fixed data, so the sample-variance check is asserted
    # unconditionally rather than behind a guard that could silently skip it
    expected_variance_sample = compute_expected_variance(values, weights, total_weight, is_sample=True)
    assert pytest.approx(ws.variance_sample(), 0.001) == expected_variance_sample

0 comments on commit a89b352

Please sign in to comment.