Skip to content

Commit

Permalink
First version of an MCMC sampler (#829)
Browse files Browse the repository at this point in the history
* First version of an MCMC sampler

* More explicit check for free shared params

* Progress bar, example notebook, and first parallelization attempt

* Return burn in and changes requested by Thomas

* Set random state to None and clean up MCMC notebook

* Check for metric and correct if chi2

* Add sampling algorithm argument and simplify code
  • Loading branch information
JanWeldert authored Nov 8, 2024
1 parent e6bff6e commit 95eadf3
Show file tree
Hide file tree
Showing 3 changed files with 659 additions and 6 deletions.
135 changes: 135 additions & 0 deletions pisa/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pisa.utils.comparisons import recursiveEquality, FTYPE_PREC, ALLCLOSE_KW
from pisa.utils.log import logging, set_verbosity
from pisa.utils.fileio import to_file
from pisa.utils.random_numbers import get_random_state
from pisa.utils.stats import (METRICS_TO_MAXIMIZE, METRICS_TO_MINIMIZE,
LLH_METRICS, CHI2_METRICS, weighted_chi2,
it_got_better, is_metric_to_maximize)
Expand Down Expand Up @@ -2681,6 +2682,140 @@ def _minimizer_callback(self, xk, **unused_kwargs): # pylint: disable=unused-arg
"""
self._nit += 1

def MCMC_sampling(self, data_dist, hypo_maker, metric, nwalkers, burnin, nsteps,
                  return_burn_in=False, random_state=None, sampling_algorithm=None):
    """Performs MCMC sampling of the free params of `hypo_maker` using emcee.

    Only supports serial (single CPU) execution at the moment. See issue #830.

    Parameters
    ----------
    data_dist : Sequence of MapSets or MapSet
        Data distribution to be fit. Can be an actual-, Asimov-, or pseudo-data
        distribution (where the latter two are derived from simulation and so
        aren't technically "data").

    hypo_maker : Detectors or DistributionMaker
        Creates the per-bin expectation values per map based on its param values.
        Free params in the `hypo_maker` are modified by the sampler while it
        explores the parameter space.

    metric : string or single-element iterable of strings
        Metric by which to evaluate the fit; its name must contain either 'llh'
        or 'chi2'. See documentation of Map.

    nwalkers : int
        Number of walkers

    burnin : int
        Number of steps in burn in phase

    nsteps : int
        Number of steps after burn in

    return_burn_in : bool
        Also return the steps of the burn in phase. Default is False.

    random_state : None or type accepted by utils.random_numbers.get_random_state
        Random state of the walker starting points. Default is None.

    sampling_algorithm : None or emcee.moves object
        Sampling algorithm used by the emcee sampler. None means to use the
        default which is a Goodman & Weare "stretch move" with parallelization.
        See https://emcee.readthedocs.io/en/stable/user/moves/#moves-user to
        learn more about the emcee sampling algorithms.

    Returns
    -------
    scaled_chain : numpy array
        Array containing all points in the parameter space visited by each
        walker. It is sorted by steps, so all the first steps of all walkers
        come first. To for example get all values of the Nth parameter and the
        ith walker, use scaled_chain[i::nwalkers, N].

    scaled_chain_burnin : numpy array (optional)
        Same as scaled_chain, but for the burn in phase. Only returned if
        `return_burn_in` is True.

    Raises
    ------
    ValueError
        If `metric` is an iterable that does not hold exactly one entry.

    AssertionError
        If `metric` is neither a llh nor a chi2 metric.

    """
    import emcee

    # The checks below rely on substring tests against the metric name, so a
    # (single-element) iterable has to be unwrapped first.
    if not isinstance(metric, str):
        metric = list(metric)
        if len(metric) != 1:
            raise ValueError(
                'MCMC sampling needs exactly one metric, got %d.' % len(metric)
            )
        metric = metric[0]

    # Factor that converts the metric into llh units. This is an explicit raise
    # (rather than an `assert` statement) so that the check survives `python -O`;
    # the exception type is kept for backward compatibility.
    if 'llh' in metric:
        llh_scale = 1
    elif 'chi2' in metric:
        llh_scale = 0.5
        warnings.warn("You are using a chi2 metric for the MCMC sampling. "
                      "The sampler will assume that llh=0.5*chi2.")
    else:
        raise AssertionError('Use either a llh or chi2 metric')

    # emcee maximizes the returned value, so metrics to be minimized are negated
    sign = +1 if metric in METRICS_TO_MAXIMIZE else -1

    ndim = len(hypo_maker.params.free)
    # The walkers move in the rescaled parameter space, bounded to [0, 1]
    lower_bound, upper_bound = 0, 1
    rs = get_random_state(random_state)
    p0 = rs.random(ndim * nwalkers).reshape((nwalkers, ndim))

    def log_prob(scaled_param_vals):
        """Function called by the MCMC sampler. Similar to _minimizer_callable
        it returns the current metric value + prior penalties.

        """
        if (np.any(scaled_param_vals > upper_bound)
                or np.any(scaled_param_vals < lower_bound)):
            return -np.inf

        hypo_maker._set_rescaled_free_params(scaled_param_vals) # pylint: disable=protected-access
        hypo_asimov_dist = hypo_maker.get_outputs(return_sum=True)
        # NOTE(review): only the data term is scaled by `llh_scale`. If
        # `priors_penalty` returns the penalty in chi2 units for chi2 metrics,
        # it would need the same factor -- confirm against ParamSet.
        metric_val = (
            llh_scale * data_dist.metric_total(expected_values=hypo_asimov_dist,
                                               metric=metric)
            + hypo_maker.params.priors_penalty(metric=metric)
        )
        return sign * metric_val

    def to_physical_values(flat_samples):
        """Translate a flat chain of rescaled values into (magnitudes of) the
        actual parameter values.

        """
        # presumably ParamSet(...) decouples these params from `hypo_maker`
        # (the original code called it a copy) -- TODO confirm
        params = ParamSet(hypo_maker.params.free)
        physical = np.full_like(flat_samples, np.nan, dtype=FTYPE)
        for s, sample in enumerate(flat_samples):
            for dim, rescaled_val in enumerate(sample):
                param = params[dim]
                param._rescaled_value = rescaled_val # pylint: disable=protected-access
                physical[s, dim] = param.value.m
        return physical

    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, moves=sampling_algorithm)

    if self.pprint:
        sys.stdout.write('Burn in')
        sys.stdout.flush()
    # Only the final walker positions are needed to seed the main run
    pos, _, _ = sampler.run_mcmc(p0, burnin, progress=self.pprint)

    scaled_chain_burnin = None
    if return_burn_in:
        # Has to happen before the sampler is reset, which drops the burn in chain
        scaled_chain_burnin = to_physical_values(sampler.get_chain(flat=True))

    sampler.reset()
    if self.pprint:
        sys.stdout.write('Main sampling')
        sys.stdout.flush()
    sampler.run_mcmc(pos, nsteps, progress=self.pprint)

    scaled_chain = to_physical_values(sampler.get_chain(flat=True))

    if return_burn_in:
        return scaled_chain, scaled_chain_burnin
    return scaled_chain


class Analysis(BasicAnalysis):
"""Analysis class for "canonical" IceCube/DeepCore/PINGU analyses.
Expand Down
16 changes: 10 additions & 6 deletions pisa/core/detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,19 @@ def __init__(self, pipelines, label=None, set_livetime_from_data=True, profile=F
)

for sp in self.shared_params:
n = 0
N, n = 0, 0
for distribution_maker in self._distribution_makers:
if sp in distribution_maker.params.names:
N += 1
if sp in distribution_maker.params.free.names:
n += 1
if n < 2:
raise NameError('Shared param %s only a free param in less than 2 detectors.' % sp)
if N < 2:
raise NameError(f'Shared param {sp} only exists in {N} detectors.')
if n > 0 and n != N:
raise NameError(f'Shared param {sp} exists in {N} detectors but only a free param in {n} detectors.')

self.init_params()

def __repr__(self):
return self.tabulate(tablefmt="presto")

Expand Down Expand Up @@ -225,7 +229,7 @@ def shared_param_ind_list(self):
spi = []
for p_name in free_names:
if p_name in self.shared_params:
spi.append((free_names.index(p_name),self.shared_params.index(p_name)))
spi.append((free_names.index(p_name), self.shared_params.index(p_name)))
shared_param_ind_list.append(spi)
return shared_param_ind_list

Expand Down Expand Up @@ -347,7 +351,7 @@ def _set_rescaled_free_params(self, rvalues):
for j in range(len(self._distribution_makers[i].params.free) - len(spi[i])):
rp.append(rvalues.pop(0))
for j in range(len(spi[i])):
rp.insert(spi[i][j][0],sp[spi[i][j][1]])
rp.insert(spi[i][j][0], sp[spi[i][j][1]])
self._distribution_makers[i]._set_rescaled_free_params(rp)


Expand Down
514 changes: 514 additions & 0 deletions pisa_examples/MCMC_example.ipynb

Large diffs are not rendered by default.

0 comments on commit 95eadf3

Please sign in to comment.