diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index f802eaa..0dced07 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -7,6 +7,7 @@ from collections.abc import Sequence from typing import TYPE_CHECKING +import dill import jax import jax.numpy as jnp import nautilus @@ -557,6 +558,8 @@ def nuts( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, + filepath: str | None = None, + resume: bool = False, **nuts_kwargs: dict, ) -> PosteriorResult: """Run the No-U-Turn Sampler of :mod:`numpyro`. @@ -581,6 +584,12 @@ def nuts( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. + filepath : str, optional + Path to the file where the last state is saved. Must have `.pkl` extension. + If None, no file is written. Default is None. + resume : bool, optional + If True, read the last state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is False. **nuts_kwargs : dict Extra parameters passed to :class:`numpyro.infer.NUTS`. 
@@ -631,10 +640,46 @@ def nuts( progress_bar=progress, ) - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - extra_fields=('energy', 'num_steps'), - ) + if resume: + try: + with open(filepath, 'rb') as f: + last_state = dill.load(f) + print('Load last state file...') + sampler.post_warmup_state = last_state + print('Sampling...') + sampler.run( + rng_key=sampler.post_warmup_state.rng_key, + extra_fields=('energy', 'num_steps'), + ) + except: + print('Failed to load last state file.') + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + + elif warmup > 0: + print('Warming up') + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + print('Sampling...') + sampler.run( + rng_key=sampler.post_warmup_state.rng_key, + extra_fields=('energy', 'num_steps'), + ) + else: + print('Sampling...') + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + + if filepath is not None: + with open(filepath, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) def jaxns( @@ -724,7 +769,7 @@ def jaxns( termination_kwargs=termination_kwargs, ) - print('Start nested sampling...') + print('Start nested') t0 = time.time() sampler.run(rng_key=jax.random.PRNGKey(self._helper.seed['mcmc'])) print(f'Sampling cost {time.time() - t0:.2f} s') @@ -825,7 +870,7 @@ def transform_(samples): sampler._transform_back = transform_ if read_file is None: - print('Start nested sampling...') + print('Start nested') t0 = time.time() sampler.run(min_ess=int(ess), **termination_kwargs) print(f'Sampling cost {time.time() - t0:.2f} s') @@ -934,7 +979,7 @@ def transform_(samples): termination_kwargs['discard_exploration'] = True termination_kwargs.setdefault('verbose', True) - print('Start nested sampling...') + print('Start nested') t0 = time.time() success 
= sampler.run(n_eff=int(ess), **termination_kwargs) if success: @@ -957,6 +1002,8 @@ def aies( chain_method: str = 'vectorized', n_parallel: int | None = None, progress: bool = True, + filepath: str | None = None, + resume: bool = False, moves: dict | None = None, **aies_kwargs: dict, ) -> PosteriorResult: @@ -991,7 +1038,14 @@ def aies( ``"parallel"``. Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to show progress bar during sampling. The default is True. - If `chain_method` is set to ``'parallel'``, this is always False. + If `chain_method` is set to ``'parallel'``, this is + always False after warmup. + filepath : str, optional + Path to the file where the last state is saved. Must have `.pkl` extension. + If None, no file is written. Default is None. + resume : bool, optional + If True, read the last state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is False. moves : dict, optional Moves for the sampler. **aies_kwargs : dict @@ -1036,52 +1090,91 @@ def aies( else: aies_kwargs['moves'] = moves - if chain_method == 'parallel': - aies_kernel = AIES(**aies_kwargs) - - def do_mcmc(rng_key): - mcmc = MCMC( - aies_kernel, - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method='vectorized', - progress_bar=False, + rng_key = jax.random.PRNGKey(self._helper.seed['mcmc']) + + sampler = MCMC( + AIES(**aies_kwargs), + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method='vectorized', + progress_bar=progress, + ) + + if resume: + try: + with open(filepath, 'rb') as f: + last_state = dill.load(f) + print('Load last state file...') + sampler.post_warmup_state = last_state + print('Sampling...') + run_ensemble( + sampler, + AIES(**aies_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - mcmc.run( + except: + print('Failed to load last state file.') + if warmup > 0: + print('Warming up') + sampler.warmup( + 
rng_key=rng_key, + init_params=init, + ) + print('Sampling...') + run_ensemble( + sampler, + AIES(**aies_kwargs), rng_key, - init_params=init, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - return mcmc.get_samples(group_by_chain=True) - - rng_keys = jax.random.split( - jax.random.PRNGKey(self._helper.seed['mcmc']), - get_parallel_number(n_parallel), + elif warmup > 0: + print('Warming up') + sampler.warmup( + rng_key=rng_key, + init_params=init, ) - traces = jax.pmap(do_mcmc)(rng_keys) - trace = {k: np.concatenate(v) for k, v in traces.items()} - - sampler = MCMC( - aies_kernel, - num_warmup=warmup, - num_samples=steps, + print('Sampling...') + run_ensemble( + sampler, + AIES(**aies_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - sampler._states = {sampler._sample_field: trace} - else: - sampler = MCMC( + print('Sampling...') + run_ensemble( + sampler, AIES(**aies_kwargs), - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method=chain_method, - progress_bar=progress, + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) + if filepath is not None: + with open(filepath, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) def ess( @@ -1094,6 +1187,8 @@ def ess( n_parallel: int | None = None, progress: bool = True, moves: dict | None = None, + filepath: str | None = None, + resume: bool = False, **ess_kwargs: dict, ) -> PosteriorResult: """Ensemble Slice Sampling (ESS) of :mod:`numpyro`. @@ -1127,7 +1222,14 @@ def ess( ``"parallel"``. Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to show progress bar during sampling. The default is True. - If `chain_method` is set to ``'parallel'``, this is always False. + If `chain_method` is set to ``'parallel'``, this is + always False after warmup. 
+ filepath : str, optional + Path to the file where the last state is saved. Must have `.pkl` extension. + If None, no file is written. Default is None. + resume : bool, optional + If True, read the last state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is False. moves : dict, optional Moves for the sampler. **ess_kwargs : dict @@ -1169,52 +1271,91 @@ def ess( else: ess_kwargs['moves'] = moves - if chain_method == 'parallel': - ess_kernel = ESS(**ess_kwargs) - - def do_mcmc(rng_key): - mcmc = MCMC( - ess_kernel, - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method='vectorized', - progress_bar=False, + rng_key = jax.random.PRNGKey(self._helper.seed['mcmc']) + + sampler = MCMC( + ESS(**ess_kwargs), + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method='vectorized', + progress_bar=progress, + ) + + if resume: + try: + with open(filepath, 'rb') as f: + last_state = dill.load(f) + print('Load last state file...') + sampler.post_warmup_state = last_state + print('Sampling...') + run_ensemble( + sampler, + ESS(**ess_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - mcmc.run( + except: + print('Failed to load last state file.') + if warmup > 0: + print('Warming up') + sampler.warmup( + rng_key=rng_key, + init_params=init, + ) + print('Sampling...') + run_ensemble( + sampler, + ESS(**ess_kwargs), rng_key, - init_params=init, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - return mcmc.get_samples(group_by_chain=True) - - rng_keys = jax.random.split( - jax.random.PRNGKey(self._helper.seed['mcmc']), - get_parallel_number(n_parallel), + elif warmup > 0: + print('Warming up') + sampler.warmup( + rng_key=rng_key, + init_params=init, ) - traces = jax.pmap(do_mcmc)(rng_keys) - trace = {k: np.concatenate(v) for k, v in traces.items()} - - sampler = MCMC( - ess_kernel, - num_warmup=warmup, - num_samples=steps, + 
print('Sampling...') + run_ensemble( + sampler, + ESS(**ess_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - sampler._states = {sampler._sample_field: trace} - else: - sampler = MCMC( + print('Sampling...') + run_ensemble( + sampler, ESS(**ess_kwargs), - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method=chain_method, - progress_bar=progress, + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, ) - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) + if filepath is not None: + with open(filepath, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) def sa( @@ -1225,6 +1366,8 @@ def sa( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, + filepath: str | None = None, + resume: bool = False, **sa_kwargs: dict, ) -> PosteriorResult: """Run the Sample Adaptive MCMC of :mod:`numpyro`. @@ -1247,6 +1390,12 @@ def sa( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. + filepath : str, optional + Path to the file where last state are saved. Must have `.pkl` extension. + If None, no file are written. Default is None. + resume : bool, optional + If True, read the last state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is True. **sa_kwargs : dict Extra parameters passed to :class:`numpyro.infer.SA`. 
# Temporary helper: runs the sampling phase of an ensemble sampler, optionally
# fanned out over several devices with ``jax.pmap``.
def run_ensemble(
    sampler,
    kernel,
    rng_key,
    warmup,
    steps,
    chains,
    init_params,
    chain_method='vectorized',
    n_parallel=None,
):
    """Run the sampling phase for an ensemble kernel and store the draws.

    Parameters
    ----------
    sampler : MCMC
        The sampler that receives the result. When `chain_method` is not
        ``'parallel'`` it is run directly; otherwise only its ``_states``
        attribute is overwritten with the concatenated draws.
    kernel
        Kernel instance handed to the per-device ``MCMC`` in parallel mode.
        Unused otherwise.
    rng_key
        Random key. Used as-is in the non-parallel path; split into one key
        per parallel worker in the parallel path.
    warmup, steps, chains : int
        Forwarded as ``num_warmup``, ``num_samples`` and ``num_chains`` to the
        per-device ``MCMC`` (parallel mode only).
    init_params
        Forwarded to ``run`` as ``init_params``.
    chain_method : str, optional
        ``'parallel'`` selects the ``jax.pmap`` path; any other value calls
        ``sampler.run`` directly. The default is ``'vectorized'``.
    n_parallel : int, optional
        Passed to ``get_parallel_number`` to decide how many keys to split.

    Returns
    -------
    None
        Results are stored on `sampler`.
    """
    if chain_method != 'parallel':
        sampler.run(
            rng_key=rng_key,
            init_params=init_params,
        )
        return

    paral_mcmc = MCMC(
        kernel,
        num_warmup=warmup,
        num_samples=steps,
        num_chains=chains,
        chain_method='vectorized',
        progress_bar=False,
    )

    # Seed the per-device sampler with whatever state the outer sampler holds.
    # ``last_state`` is what the original code forwarded (set once the sampler
    # has actually run, e.g. after ``warmup()``). When resuming, the callers
    # only assign ``post_warmup_state`` and never run the sampler, so fall
    # back to that attribute instead of silently discarding the loaded state.
    start_state = sampler.last_state
    if start_state is None:
        start_state = sampler.post_warmup_state
    if start_state is not None:
        paral_mcmc.post_warmup_state = start_state

    def do_mcmc(device_key):
        paral_mcmc.run(
            device_key,
            init_params=init_params,
        )
        return paral_mcmc.get_samples(group_by_chain=True)

    device_keys = jax.random.split(
        rng_key,
        get_parallel_number(n_parallel),
    )
    traces = jax.pmap(do_mcmc)(device_keys)

    # Merge the leading per-device axis into the chain axis.
    trace = {name: np.concatenate(draws) for name, draws in traces.items()}
    sampler._states = {sampler._sample_field: trace}
    # NOTE(review): this branch never calls ``sampler.run``, so
    # ``sampler.last_state`` keeps whatever value it had before the call.
    # Callers that persist ``sampler.last_state`` afterwards will therefore
    # not save the end-of-sampling state in parallel mode — confirm intent.