From 06ee92be6d441aabbfa8408e13c2ba8e78754540 Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Wed, 27 Nov 2024 13:59:18 +0800 Subject: [PATCH 01/13] save and load warmup state for `NUT` and `SA` --- src/elisa/infer/fit.py | 65 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 58 insertions(+), 7 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index f802eaa..c6c7242 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -7,6 +7,10 @@ from collections.abc import Sequence from typing import TYPE_CHECKING +import bz2 +import gzip +import lzma +import dill import jax import jax.numpy as jnp import nautilus @@ -557,6 +561,8 @@ def nuts( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, + save_warmup: str | None = None, + load_warmup: str | None = None, **nuts_kwargs: dict, ) -> PosteriorResult: """Run the No-U-Turn Sampler of :mod:`numpyro`. @@ -581,6 +587,10 @@ def nuts( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. + save_warmup : str, optional + Path to save the warmup file. The default is None. + load_warmup : str, optional + Path to load the warmup file. The default is None. **nuts_kwargs : dict Extra parameters passed to :class:`numpyro.infer.NUTS`. @@ -631,10 +641,28 @@ def nuts( progress_bar=progress, ) - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - extra_fields=('energy', 'num_steps'), - ) + if load_warmup is not None: + with gzip.open(load_warmup, 'rb') as f: + last_state = dill.load(f) + sampler.post_warmup_state = last_state + sampler.run(sampler.post_warmup_state.rng_key) + + elif warmup > 0: + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + sampler.run(sampler.post_warmup_state.rng_key) + else: + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + + if save_warmup is not None: + with gzip.open(save_warmup, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) def jaxns( @@ -1225,6 +1253,8 @@ def sa( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, + save_warmup: str | None = None, + load_warmup: str | None = None, **sa_kwargs: dict, ) -> PosteriorResult: """Run the Sample Adaptive MCMC of :mod:`numpyro`. @@ -1247,6 +1277,10 @@ def sa( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. + save_warmup : str, optional + Path to save the warmup file. The default is None. + load_warmup : str, optional + Path to load the warmup file. The default is None. **sa_kwargs : dict Extra parameters passed to :class:`numpyro.infer.SA`. @@ -1288,7 +1322,24 @@ def sa( progress_bar=progress, ) - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - ) + if load_warmup is not None: + with gzip.open(load_warmup, 'rb') as f: + last_state = dill.load(f) + sampler.post_warmup_state = last_state + sampler.run(sampler.post_warmup_state.rng_key) + + elif warmup > 0: + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + ) + sampler.run(sampler.post_warmup_state.rng_key) + else: + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + ) + + if save_warmup is not None: + with gzip.open(save_warmup, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) From d1cdd743656b3d6173a3de1a6b49408874fdb5e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 06:00:12 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/elisa/infer/fit.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index c6c7242..07bb180 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -2,14 +2,12 @@ from __future__ import annotations +import gzip import time from abc import ABC, abstractmethod from collections.abc import Sequence from typing import TYPE_CHECKING -import bz2 -import gzip -import lzma import dill import jax import jax.numpy as jnp From 52c9e779ea3cf84e6f26e28863432cd4f1263a5b Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Thu, 28 Nov 2024 13:48:02 +0800 Subject: [PATCH 03/13] remove gzip --- src/elisa/infer/fit.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 07bb180..056a81f 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -2,7 +2,6 @@ from __future__ import annotations -import gzip import time from abc import ABC, abstractmethod from collections.abc import Sequence @@ -640,7 +639,7 @@ def nuts( ) if load_warmup is not None: - with gzip.open(load_warmup, 'rb') as f: + with open(load_warmup, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state sampler.run(sampler.post_warmup_state.rng_key) @@ -658,7 +657,7 @@ def nuts( ) if save_warmup is not None: - with gzip.open(save_warmup, 'wb') as f: + with open(save_warmup, 'wb') as f: dill.dump(sampler.last_state, f) return PosteriorResult(sampler, self._helper, self) @@ -1321,7 +1320,7 @@ def sa( ) if load_warmup is not None: - with gzip.open(load_warmup, 'rb') as f: + with open(load_warmup, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state sampler.run(sampler.post_warmup_state.rng_key) @@ -1337,7 +1336,7 @@ def sa( ) if save_warmup is not None: - with gzip.open(save_warmup, 'wb') as f: + with open(save_warmup, 'wb') as f: dill.dump(sampler.last_state, f) return PosteriorResult(sampler, self._helper, self) From f996c713aeb0d55c816d1b0853af245df34e8576 Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Thu, 28 Nov 2024 15:39:39 +0800 Subject: [PATCH 04/13] Update load_warmup --- src/elisa/infer/fit.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 056a81f..b8559ed 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -7,6 +7,7 @@ from collections.abc import Sequence from typing import TYPE_CHECKING +import os import dill import jax import jax.numpy as jnp @@ -639,10 +640,16 @@ def nuts( ) if load_warmup is not None: - with open(load_warmup, 'rb') as f: - last_state = dill.load(f) - sampler.post_warmup_state = last_state - sampler.run(sampler.post_warmup_state.rng_key) + if os.path.exists(load_warmup): + with open(load_warmup, 'rb') as f: + last_state = dill.load(f) + sampler.post_warmup_state = last_state + sampler.run(sampler.post_warmup_state.rng_key) + else: + print(f"{load_warmup} not found!\nRunning sampling...") + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + ) elif warmup > 0: sampler.warmup( @@ -1320,10 +1327,16 @@ def sa( ) if load_warmup is not None: - with open(load_warmup, 'rb') as f: - last_state = dill.load(f) - sampler.post_warmup_state = last_state - sampler.run(sampler.post_warmup_state.rng_key) + if os.path.exists(load_warmup): + with open(load_warmup, 'rb') as f: + last_state = dill.load(f) + sampler.post_warmup_state = last_state + sampler.run(sampler.post_warmup_state.rng_key) + else: + print(f"{load_warmup} not found!\nRunning sampling...") + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + ) elif warmup > 0: sampler.warmup( From 33a12fcde29dc47e6168de98dc6e15e86d970765 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Nov 2024 07:41:20 +0000 Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/elisa/infer/fit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index b8559ed..09e1570 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -2,12 +2,12 @@ from __future__ import annotations +import os import time from abc import ABC, abstractmethod from collections.abc import Sequence from typing import TYPE_CHECKING -import os import dill import jax import jax.numpy as jnp @@ -646,7 +646,7 @@ def nuts( sampler.post_warmup_state = last_state sampler.run(sampler.post_warmup_state.rng_key) else: - print(f"{load_warmup} not found!\nRunning sampling...") + print(f'{load_warmup} not found!\nRunning sampling...') sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) @@ -1333,7 +1333,7 @@ def sa( sampler.post_warmup_state = last_state sampler.run(sampler.post_warmup_state.rng_key) else: - print(f"{load_warmup} not found!\nRunning sampling...") + print(f'{load_warmup} not found!\nRunning sampling...') sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) From 4abbb53af2d1b4d560a1ce94c070d38172c592ca Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Fri, 29 Nov 2024 15:29:46 +0800 Subject: [PATCH 06/13] Update aies ess parallel --- src/elisa/infer/fit.py | 221 ++++++++++++++++++++++------------------- 1 file changed, 119 insertions(+), 102 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 09e1570..0808b60 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -2,7 +2,6 @@ from __future__ import annotations -import os import time from abc import ABC, abstractmethod from collections.abc import Sequence @@ -559,8 +558,6 @@ def nuts( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, - save_warmup: str | None = None, - load_warmup: str | None = None, **nuts_kwargs: dict, ) -> PosteriorResult: """Run the No-U-Turn Sampler of :mod:`numpyro`. @@ -585,10 +582,6 @@ def nuts( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. - save_warmup : str, optional - Path to save the warmup file. The default is None. - load_warmup : str, optional - Path to load the warmup file. The default is None. **nuts_kwargs : dict Extra parameters passed to :class:`numpyro.infer.NUTS`. @@ -639,34 +632,10 @@ def nuts( progress_bar=progress, ) - if load_warmup is not None: - if os.path.exists(load_warmup): - with open(load_warmup, 'rb') as f: - last_state = dill.load(f) - sampler.post_warmup_state = last_state - sampler.run(sampler.post_warmup_state.rng_key) - else: - print(f'{load_warmup} not found!\nRunning sampling...') - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - ) - - elif warmup > 0: - sampler.warmup( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - extra_fields=('energy', 'num_steps'), - ) - sampler.run(sampler.post_warmup_state.rng_key) - else: - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - extra_fields=('energy', 'num_steps'), - ) - - if save_warmup is not None: - with open(save_warmup, 'wb') as f: - dill.dump(sampler.last_state, f) - + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) return PosteriorResult(sampler, self._helper, self) def jaxns( @@ -989,6 +958,7 @@ def aies( chain_method: str = 'vectorized', n_parallel: int | None = None, progress: bool = True, + resume_sample: str | None = None, moves: dict | None = None, **aies_kwargs: dict, ) -> PosteriorResult: @@ -1023,7 +993,12 @@ def aies( ``"parallel"``. Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to show progress bar during sampling. The default is True. - If `chain_method` is set to ``'parallel'``, this is always False. + If `chain_method` is set to ``'parallel'``, this is + always False after warmup. + resume_sample : str, optional + Read the last_state file from a previous run, and then, sampling + will skip the warmup adaptation phase. Finally, it saves last_state, + whether there is a last_state file or not. moves : dict, optional Moves for the sampler. **aies_kwargs : dict @@ -1068,23 +1043,52 @@ def aies( else: aies_kwargs['moves'] = moves + # warmup at least 10 + warmup = 10 if warmup<10 else warmup + + sampler = MCMC( + AIES(**aies_kwargs), + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method='vectorized', + progress_bar=progress, + ) + + if resume_sample is not None: + try: + with open(resume_sample, 'rb') as f: + last_state = dill.load(f) + sampler.post_warmup_state = last_state + except: + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + init_params=init, + ) + else: + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + init_params=init, + ) + if chain_method == 'parallel': - aies_kernel = AIES(**aies_kwargs) + print('Parallel sampling...') + paral_mcmc = MCMC( + AIES(**aies_kwargs), + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method='vectorized', + progress_bar=False, + ) + paral_mcmc.post_warmup_state = sampler.last_state def do_mcmc(rng_key): - mcmc = MCMC( - aies_kernel, - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method='vectorized', - progress_bar=False, - ) - mcmc.run( + paral_mcmc.run( rng_key, init_params=init, ) - return mcmc.get_samples(group_by_chain=True) + return paral_mcmc.get_samples(group_by_chain=True) rng_keys = jax.random.split( jax.random.PRNGKey(self._helper.seed['mcmc']), @@ -1092,28 +1096,18 @@ def do_mcmc(rng_key): ) traces = jax.pmap(do_mcmc)(rng_keys) trace = {k: np.concatenate(v) for k, v in traces.items()} - - sampler = MCMC( - aies_kernel, - num_warmup=warmup, - num_samples=steps, - ) sampler._states = {sampler._sample_field: trace} else: - sampler = MCMC( - AIES(**aies_kwargs), - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method=chain_method, - progress_bar=progress, - ) - sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), init_params=init, ) + + if resume_sample is not None: + with open(resume_sample, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) def ess( @@ -1126,6 +1120,7 @@ def ess( n_parallel: int | None = None, progress: bool = True, moves: dict | None = None, + resume_sample: str | None = None, **ess_kwargs: dict, ) -> PosteriorResult: """Ensemble Slice Sampling (ESS) of :mod:`numpyro`. @@ -1159,7 +1154,12 @@ def ess( ``"parallel"``. Defaults to ``jax.local_device_count()``. progress : bool, optional Whether to show progress bar during sampling. The default is True. - If `chain_method` is set to ``'parallel'``, this is always False. + If `chain_method` is set to ``'parallel'``, this is + always False after warmup. + resume_sample : str, optional + Read the last_state file from a previous run, and then, sampling + will skip the warmup adaptation phase. Finally, it saves last_state, + whether there is a last_state file or not. moves : dict, optional Moves for the sampler. **ess_kwargs : dict @@ -1201,23 +1201,52 @@ def ess( else: ess_kwargs['moves'] = moves + # warmup at least 10 + warmup = 10 if warmup<10 else warmup + + sampler = MCMC( + ESS(**ess_kwargs), + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method='vectorized', + progress_bar=progress, + ) + + if resume_sample is not None: + try: + with open(resume_sample, 'rb') as f: + last_state = dill.load(f) + sampler.post_warmup_state = last_state + except: + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + init_params=init, + ) + else: + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + init_params=init, + ) + if chain_method == 'parallel': - ess_kernel = ESS(**ess_kwargs) + print('Parallel sampling...') + paral_mcmc = MCMC( + ESS(**ess_kwargs), + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method='vectorized', + progress_bar=False, + ) + paral_mcmc.post_warmup_state = sampler.last_state def do_mcmc(rng_key): - mcmc = MCMC( - ess_kernel, - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method='vectorized', - progress_bar=False, - ) - mcmc.run( + paral_mcmc.run( rng_key, init_params=init, ) - return mcmc.get_samples(group_by_chain=True) + return paral_mcmc.get_samples(group_by_chain=True) rng_keys = jax.random.split( jax.random.PRNGKey(self._helper.seed['mcmc']), @@ -1225,28 +1254,18 @@ def do_mcmc(rng_key): ) traces = jax.pmap(do_mcmc)(rng_keys) trace = {k: np.concatenate(v) for k, v in traces.items()} - - sampler = MCMC( - ess_kernel, - num_warmup=warmup, - num_samples=steps, - ) sampler._states = {sampler._sample_field: trace} else: - sampler = MCMC( - ESS(**ess_kwargs), - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method=chain_method, - progress_bar=progress, - ) - sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), init_params=init, ) + + if resume_sample is not None: + with open(resume_sample, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) def sa( @@ -1257,8 +1276,7 @@ def sa( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, - save_warmup: str | None = None, - load_warmup: str | None = None, + resume_sample: str | None = None, **sa_kwargs: dict, ) -> PosteriorResult: """Run the Sample Adaptive MCMC of :mod:`numpyro`. @@ -1281,10 +1299,10 @@ def sa( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. - save_warmup : str, optional - Path to save the warmup file. The default is None. - load_warmup : str, optional - Path to load the warmup file. The default is None. + resume_sample : str, optional + Read the last_state file from a previous run, and then, sampling + will skip the warmup adaptation phase. Finally, it saves last_state, + whether there is a last_state file or not. **sa_kwargs : dict Extra parameters passed to :class:`numpyro.infer.SA`. @@ -1326,14 +1344,13 @@ def sa( progress_bar=progress, ) - if load_warmup is not None: - if os.path.exists(load_warmup): - with open(load_warmup, 'rb') as f: + if resume_sample is not None: + try: + with open(resume_sample, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state sampler.run(sampler.post_warmup_state.rng_key) - else: - print(f'{load_warmup} not found!\nRunning sampling...') + except: sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) @@ -1348,8 +1365,8 @@ def sa( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) - if save_warmup is not None: - with open(save_warmup, 'wb') as f: + if resume_sample is not None: + with open(resume_sample, 'wb') as f: dill.dump(sampler.last_state, f) return PosteriorResult(sampler, self._helper, self) From 8dbde46cef02b0818a7ea21ceb7ce12748754108 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Nov 2024 07:29:52 +0000 Subject: [PATCH 07/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/elisa/infer/fit.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 0808b60..ea510dc 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -996,8 +996,8 @@ def aies( If `chain_method` is set to ``'parallel'``, this is always False after warmup. resume_sample : str, optional - Read the last_state file from a previous run, and then, sampling - will skip the warmup adaptation phase. Finally, it saves last_state, + Read the last_state file from a previous run, and then, sampling + will skip the warmup adaptation phase. Finally, it saves last_state, whether there is a last_state file or not. moves : dict, optional Moves for the sampler. @@ -1044,7 +1044,7 @@ def aies( aies_kwargs['moves'] = moves # warmup at least 10 - warmup = 10 if warmup<10 else warmup + warmup = 10 if warmup < 10 else warmup sampler = MCMC( AIES(**aies_kwargs), @@ -1107,7 +1107,7 @@ def do_mcmc(rng_key): if resume_sample is not None: with open(resume_sample, 'wb') as f: dill.dump(sampler.last_state, f) - + return PosteriorResult(sampler, self._helper, self) def ess( @@ -1157,8 +1157,8 @@ def ess( If `chain_method` is set to ``'parallel'``, this is always False after warmup. resume_sample : str, optional - Read the last_state file from a previous run, and then, sampling - will skip the warmup adaptation phase. Finally, it saves last_state, + Read the last_state file from a previous run, and then, sampling + will skip the warmup adaptation phase. Finally, it saves last_state, whether there is a last_state file or not. moves : dict, optional Moves for the sampler. @@ -1202,7 +1202,7 @@ def ess( ess_kwargs['moves'] = moves # warmup at least 10 - warmup = 10 if warmup<10 else warmup + warmup = 10 if warmup < 10 else warmup sampler = MCMC( ESS(**ess_kwargs), @@ -1300,8 +1300,8 @@ def sa( progress : bool, optional Whether to show progress bar during sampling. The default is True. resume_sample : str, optional - Read the last_state file from a previous run, and then, sampling - will skip the warmup adaptation phase. Finally, it saves last_state, + Read the last_state file from a previous run, and then, sampling + will skip the warmup adaptation phase. Finally, it saves last_state, whether there is a last_state file or not. **sa_kwargs : dict Extra parameters passed to :class:`numpyro.infer.SA`. From 176891edf068ab9661c839399861cb96ec9c7551 Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Sat, 30 Nov 2024 15:26:39 +0800 Subject: [PATCH 08/13] temporarily ensemble parallelled run --- src/elisa/infer/fit.py | 154 ++++++++++++++++++----------------------- 1 file changed, 67 insertions(+), 87 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index ea510dc..039da95 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -1043,8 +1043,7 @@ def aies( else: aies_kwargs['moves'] = moves - # warmup at least 10 - warmup = 10 if warmup < 10 else warmup + rng_key = jax.random.PRNGKey(self._helper.seed['mcmc']) sampler = MCMC( AIES(**aies_kwargs), @@ -1060,49 +1059,20 @@ def aies( with open(resume_sample, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state + run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) except: - sampler.warmup( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) - else: - sampler.warmup( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) - - if chain_method == 'parallel': - print('Parallel sampling...') - paral_mcmc = MCMC( - AIES(**aies_kwargs), - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method='vectorized', - progress_bar=False, - ) - paral_mcmc.post_warmup_state = sampler.last_state - - def do_mcmc(rng_key): - paral_mcmc.run( - rng_key, - init_params=init, - ) - return paral_mcmc.get_samples(group_by_chain=True) - - rng_keys = jax.random.split( - jax.random.PRNGKey(self._helper.seed['mcmc']), - get_parallel_number(n_parallel), - ) - traces = jax.pmap(do_mcmc)(rng_keys) - trace = {k: np.concatenate(v) for k, v in traces.items()} - sampler._states = {sampler._sample_field: trace} - + if warmup>0: + sampler.warmup(rng_key=rng_key, init_params=init,) + run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) + elif warmup > 0: + sampler.warmup(rng_key=rng_key, init_params=init,) + run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) else: - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) + run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) if resume_sample is not None: with open(resume_sample, 'wb') as f: @@ -1200,9 +1170,8 @@ def ess( ess_kwargs['moves'] = {ESS.DifferentialMove(): 1.0} else: ess_kwargs['moves'] = moves - - # warmup at least 10 - warmup = 10 if warmup < 10 else warmup + + rng_key = jax.random.PRNGKey(self._helper.seed['mcmc']) sampler = MCMC( ESS(**ess_kwargs), @@ -1218,49 +1187,20 @@ def ess( with open(resume_sample, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state + run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) except: - sampler.warmup( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) - else: - sampler.warmup( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) - - if chain_method == 'parallel': - print('Parallel sampling...') - paral_mcmc = MCMC( - ESS(**ess_kwargs), - num_warmup=warmup, - num_samples=steps, - num_chains=chains, - chain_method='vectorized', - progress_bar=False, - ) - paral_mcmc.post_warmup_state = sampler.last_state - - def do_mcmc(rng_key): - paral_mcmc.run( - rng_key, - init_params=init, - ) - return paral_mcmc.get_samples(group_by_chain=True) - - rng_keys = jax.random.split( - jax.random.PRNGKey(self._helper.seed['mcmc']), - get_parallel_number(n_parallel), - ) - traces = jax.pmap(do_mcmc)(rng_keys) - trace = {k: np.concatenate(v) for k, v in traces.items()} - sampler._states = {sampler._sample_field: trace} - + if warmup>0: + sampler.warmup(rng_key=rng_key, init_params=init,) + run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) + elif warmup > 0: + sampler.warmup(rng_key=rng_key, init_params=init,) + run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) else: - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - init_params=init, - ) + run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, + init, chain_method, n_parallel) if resume_sample is not None: with open(resume_sample, 'wb') as f: @@ -1370,3 +1310,43 @@ def sa( dill.dump(sampler.last_state, f) return PosteriorResult(sampler, self._helper, self) + + +# temporarily for ensemble parallelled run +def run_ensemble(sampler, kernel, rng_key, warmup, steps, chains, init_params, + chain_method = 'vectorized', n_parallel=None): + if chain_method == 'parallel': + print('Parallel sampling...') + paral_mcmc = MCMC( + kernel, + num_warmup=warmup, + num_samples=steps, + num_chains=chains, + chain_method='vectorized', + progress_bar=False, + ) + + if sampler.last_state is not None: + paral_mcmc.post_warmup_state = sampler.last_state + + def do_mcmc(rng_key): + paral_mcmc.run( + rng_key, + init_params=init_params, + ) + return paral_mcmc.get_samples(group_by_chain=True) + + rng_keys = jax.random.split( + rng_key, + get_parallel_number(n_parallel), + ) + traces = jax.pmap(do_mcmc)(rng_keys) + trace = {k: np.concatenate(v) for k, v in traces.items()} + sampler._states = {sampler._sample_field: trace} + + else: + sampler.run( + rng_key=rng_key, + init_params=init_params, + ) + From 916262c44116b62fc087688c9655f848e33b75b1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 30 Nov 2024 07:26:46 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/elisa/infer/fit.py | 144 +++++++++++++++++++++++++++++++++-------- 1 file changed, 118 insertions(+), 26 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 039da95..aeca818 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -1059,20 +1059,62 @@ def aies( with open(resume_sample, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state - run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + run_ensemble( + sampler, + AIES(**aies_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) except: - if warmup>0: - sampler.warmup(rng_key=rng_key, init_params=init,) - run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + if warmup > 0: + sampler.warmup( + rng_key=rng_key, + init_params=init, + ) + run_ensemble( + sampler, + AIES(**aies_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) elif warmup > 0: - sampler.warmup(rng_key=rng_key, init_params=init,) - run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + sampler.warmup( + rng_key=rng_key, + init_params=init, + ) + run_ensemble( + sampler, + AIES(**aies_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) else: - run_ensemble(sampler, AIES(**aies_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + run_ensemble( + sampler, + AIES(**aies_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) if resume_sample is not None: with open(resume_sample, 'wb') as f: @@ -1170,7 +1212,7 @@ def ess( ess_kwargs['moves'] = {ESS.DifferentialMove(): 1.0} else: ess_kwargs['moves'] = moves - + rng_key = jax.random.PRNGKey(self._helper.seed['mcmc']) sampler = MCMC( @@ -1187,20 +1229,62 @@ def ess( with open(resume_sample, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state - run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + run_ensemble( + sampler, + ESS(**ess_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) except: - if warmup>0: - sampler.warmup(rng_key=rng_key, init_params=init,) - run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + if warmup > 0: + sampler.warmup( + rng_key=rng_key, + init_params=init, + ) + run_ensemble( + sampler, + ESS(**ess_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) elif warmup > 0: - sampler.warmup(rng_key=rng_key, init_params=init,) - run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + sampler.warmup( + rng_key=rng_key, + init_params=init, + ) + run_ensemble( + sampler, + ESS(**ess_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) else: - run_ensemble(sampler, ESS(**ess_kwargs), rng_key, warmup, steps, chains, - init, chain_method, n_parallel) + run_ensemble( + sampler, + ESS(**ess_kwargs), + rng_key, + warmup, + steps, + chains, + init, + chain_method, + n_parallel, + ) if resume_sample is not None: with open(resume_sample, 'wb') as f: @@ -1313,8 +1397,17 @@ def sa( # temporarily for ensemble parallelled run -def run_ensemble(sampler, kernel, rng_key, warmup, steps, chains, init_params, - chain_method = 'vectorized', n_parallel=None): +def run_ensemble( + sampler, + kernel, + rng_key, + warmup, + steps, + chains, + init_params, + chain_method='vectorized', + n_parallel=None, +): if chain_method == 'parallel': print('Parallel sampling...') paral_mcmc = MCMC( @@ -1349,4 +1442,3 @@ def do_mcmc(rng_key): rng_key=rng_key, init_params=init_params, ) - From 2d714610fb2c896d33fcb6c3a8a6830b8157edb3 Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Sun, 1 Dec 2024 15:29:04 +0800 Subject: [PATCH 10/13] Update checkpoint code --- src/elisa/infer/fit.py | 66 +++++++++++++++++++++++++----------------- 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index aeca818..ee9873a 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -958,7 +958,8 @@ def aies( chain_method: str = 'vectorized', n_parallel: int | None = None, progress: bool = True, - resume_sample: str | None = None, + filepath: str | None = None, + resume: bool = False, moves: dict | None = None, **aies_kwargs: dict, ) -> PosteriorResult: @@ -995,10 +996,12 @@ def aies( Whether to show progress bar during sampling. The default is True. If `chain_method` is set to ``'parallel'``, this is always False after warmup. - resume_sample : str, optional - Read the last_state file from a previous run, and then, sampling - will skip the warmup adaptation phase. Finally, it saves last_state, - whether there is a last_state file or not. + filepath : str, optional + Path to the file where last_state are saved. Must have `.pkl` extension. + If None, no file are written. Default is None. + resume : bool, optional + If True, read the last_state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is True. moves : dict, optional Moves for the sampler. **aies_kwargs : dict @@ -1054,9 +1057,9 @@ def aies( progress_bar=progress, ) - if resume_sample is not None: + if resume: try: - with open(resume_sample, 'rb') as f: + with open(filepath, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state run_ensemble( @@ -1071,6 +1074,7 @@ def aies( n_parallel, ) except: + print("No last_state file found. Sampling...") if warmup > 0: sampler.warmup( rng_key=rng_key, @@ -1116,8 +1120,8 @@ def aies( n_parallel, ) - if resume_sample is not None: - with open(resume_sample, 'wb') as f: + if filepath is not None: + with open(filepath, 'wb') as f: dill.dump(sampler.last_state, f) return PosteriorResult(sampler, self._helper, self) @@ -1132,7 +1136,8 @@ def ess( n_parallel: int | None = None, progress: bool = True, moves: dict | None = None, - resume_sample: str | None = None, + filepath: str | None = None, + resume: bool = False, **ess_kwargs: dict, ) -> PosteriorResult: """Ensemble Slice Sampling (ESS) of :mod:`numpyro`. @@ -1168,10 +1173,12 @@ def ess( Whether to show progress bar during sampling. The default is True. If `chain_method` is set to ``'parallel'``, this is always False after warmup. - resume_sample : str, optional - Read the last_state file from a previous run, and then, sampling - will skip the warmup adaptation phase. Finally, it saves last_state, - whether there is a last_state file or not. + filepath : str, optional + Path to the file where last_state are saved. Must have `.pkl` extension. + If None, no file are written. Default is None. + resume : bool, optional + If True, read the last_state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is True. moves : dict, optional Moves for the sampler. **ess_kwargs : dict @@ -1224,9 +1231,9 @@ def ess( progress_bar=progress, ) - if resume_sample is not None: + if resume: try: - with open(resume_sample, 'rb') as f: + with open(filepath, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state run_ensemble( @@ -1241,6 +1248,7 @@ def ess( n_parallel, ) except: + print("No last_state file found. Sampling...") if warmup > 0: sampler.warmup( rng_key=rng_key, @@ -1286,8 +1294,8 @@ def ess( n_parallel, ) - if resume_sample is not None: - with open(resume_sample, 'wb') as f: + if filepath is not None: + with open(filepath, 'wb') as f: dill.dump(sampler.last_state, f) return PosteriorResult(sampler, self._helper, self) @@ -1300,7 +1308,8 @@ def sa( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, - resume_sample: str | None = None, + filepath: str | None = None, + resume: bool = False, **sa_kwargs: dict, ) -> PosteriorResult: """Run the Sample Adaptive MCMC of :mod:`numpyro`. @@ -1323,10 +1332,12 @@ def sa( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. - resume_sample : str, optional - Read the last_state file from a previous run, and then, sampling - will skip the warmup adaptation phase. Finally, it saves last_state, - whether there is a last_state file or not. + filepath : str, optional + Path to the file where last_state are saved. Must have `.pkl` extension. + If None, no file are written. Default is None. + resume : bool, optional + If True, read the last_state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is True. **sa_kwargs : dict Extra parameters passed to :class:`numpyro.infer.SA`. @@ -1368,13 +1379,14 @@ def sa( progress_bar=progress, ) - if resume_sample is not None: + if resume: try: - with open(resume_sample, 'rb') as f: + with open(filepath, 'rb') as f: last_state = dill.load(f) sampler.post_warmup_state = last_state sampler.run(sampler.post_warmup_state.rng_key) except: + print("No last_state file found. Sampling...") sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) @@ -1389,8 +1401,8 @@ def sa( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) - if resume_sample is not None: - with open(resume_sample, 'wb') as f: + if filepath is not None: + with open(filepath, 'wb') as f: dill.dump(sampler.last_state, f) return PosteriorResult(sampler, self._helper, self) From db43279502119419023fa86d7e2df8608d0ccd9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 1 Dec 2024 07:29:11 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/elisa/infer/fit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index ee9873a..64d8c2d 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -1074,7 +1074,7 @@ def aies( n_parallel, ) except: - print("No last_state file found. Sampling...") + print('No last_state file found. Sampling...') if warmup > 0: sampler.warmup( rng_key=rng_key, @@ -1248,7 +1248,7 @@ def ess( n_parallel, ) except: - print("No last_state file found. Sampling...") + print('No last_state file found. Sampling...') if warmup > 0: sampler.warmup( rng_key=rng_key, @@ -1386,7 +1386,7 @@ def sa( sampler.post_warmup_state = last_state sampler.run(sampler.post_warmup_state.rng_key) except: - print("No last_state file found. Sampling...") + print('No last_state file found. Sampling...') sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) From 3f52dd54e924a698cc03bd216409c425af9de048 Mon Sep 17 00:00:00 2001 From: "S.-L. Xie" <82627490+xiesl97@users.noreply.github.com> Date: Thu, 5 Dec 2024 21:41:07 +0800 Subject: [PATCH 12/13] nuts warmup --- src/elisa/infer/fit.py | 92 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 75 insertions(+), 17 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 64d8c2d..6e52815 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -558,6 +558,8 @@ def nuts( init: dict[str, float] | None = None, chain_method: str = 'parallel', progress: bool = True, + filepath: str | None = None, + resume: bool = False, **nuts_kwargs: dict, ) -> PosteriorResult: """Run the No-U-Turn Sampler of :mod:`numpyro`. @@ -582,6 +584,12 @@ def nuts( The chain method passed to :class:`numpyro.infer.MCMC`. progress : bool, optional Whether to show progress bar during sampling. The default is True. + filepath : str, optional + Path to the file where last state are saved. Must have `.pkl` extension. + If None, no file are written. Default is None. + resume : bool, optional + If True, read the last state file from a previous run, and then, + sampling will skip the warmup adaptation phase. Default is True. **nuts_kwargs : dict Extra parameters passed to :class:`numpyro.infer.NUTS`. @@ -632,10 +640,42 @@ def nuts( progress_bar=progress, ) - sampler.run( - rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), - extra_fields=('energy', 'num_steps'), - ) + if resume: + try: + with open(filepath, 'rb') as f: + last_state = dill.load(f) + print('Load last state file...') + sampler.post_warmup_state = last_state + print('Sampling...') + sampler.run(rng_key=sampler.post_warmup_state.rng_key, + extra_fields=('energy', 'num_steps'),) + except: + print('Failed to load last state file.') + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + + elif warmup > 0: + print('Warming up') + sampler.warmup( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + print('Sampling...') + sampler.run(rng_key=sampler.post_warmup_state.rng_key, + extra_fields=('energy', 'num_steps'),) + else: + print('Sampling...') + sampler.run( + rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), + extra_fields=('energy', 'num_steps'), + ) + + if filepath is not None: + with open(filepath, 'wb') as f: + dill.dump(sampler.last_state, f) + return PosteriorResult(sampler, self._helper, self) def jaxns( @@ -725,7 +765,7 @@ def jaxns( termination_kwargs=termination_kwargs, ) - print('Start nested sampling...') + print('Start nested') t0 = time.time() sampler.run(rng_key=jax.random.PRNGKey(self._helper.seed['mcmc'])) print(f'Sampling cost {time.time() - t0:.2f} s') @@ -826,7 +866,7 @@ def transform_(samples): sampler._transform_back = transform_ if read_file is None: - print('Start nested sampling...') + print('Start nested') t0 = time.time() sampler.run(min_ess=int(ess), **termination_kwargs) print(f'Sampling cost {time.time() - t0:.2f} s') @@ -935,7 +975,7 @@ def transform_(samples): termination_kwargs['discard_exploration'] = True termination_kwargs.setdefault('verbose', True) - print('Start nested sampling...') + print('Start nested') t0 = time.time() success = sampler.run(n_eff=int(ess), **termination_kwargs) if success: @@ -997,10 +1037,10 @@ def aies( If `chain_method` is set to ``'parallel'``, this is always False after warmup. filepath : str, optional - Path to the file where last_state are saved. Must have `.pkl` extension. + Path to the file where last state are saved. Must have `.pkl` extension. If None, no file are written. Default is None. resume : bool, optional - If True, read the last_state file from a previous run, and then, + If True, read the last state file from a previous run, and then, sampling will skip the warmup adaptation phase. Default is True. moves : dict, optional Moves for the sampler. @@ -1061,7 +1101,9 @@ def aies( try: with open(filepath, 'rb') as f: last_state = dill.load(f) + print('Load last state file...') sampler.post_warmup_state = last_state + print('Sampling...') run_ensemble( sampler, AIES(**aies_kwargs), @@ -1074,12 +1116,14 @@ def aies( n_parallel, ) except: - print('No last_state file found. Sampling...') + print('Failed to load last state file.') if warmup > 0: + print('Warming up') sampler.warmup( rng_key=rng_key, init_params=init, ) + print('Sampling...') run_ensemble( sampler, AIES(**aies_kwargs), @@ -1092,10 +1136,12 @@ def aies( n_parallel, ) elif warmup > 0: + print('Warming up') sampler.warmup( rng_key=rng_key, init_params=init, ) + print('Sampling...') run_ensemble( sampler, AIES(**aies_kwargs), @@ -1108,6 +1154,7 @@ def aies( n_parallel, ) else: + print('Sampling...') run_ensemble( sampler, AIES(**aies_kwargs), @@ -1174,10 +1221,10 @@ def ess( If `chain_method` is set to ``'parallel'``, this is always False after warmup. filepath : str, optional - Path to the file where last_state are saved. Must have `.pkl` extension. + Path to the file where last state are saved. Must have `.pkl` extension. If None, no file are written. Default is None. resume : bool, optional - If True, read the last_state file from a previous run, and then, + If True, read the last state file from a previous run, and then, sampling will skip the warmup adaptation phase. Default is True. moves : dict, optional Moves for the sampler. @@ -1235,7 +1282,9 @@ def ess( try: with open(filepath, 'rb') as f: last_state = dill.load(f) + print('Load last state file...') sampler.post_warmup_state = last_state + print('Sampling...') run_ensemble( sampler, ESS(**ess_kwargs), @@ -1248,12 +1297,14 @@ def ess( n_parallel, ) except: - print('No last_state file found. Sampling...') + print('Failed to load last state file.') if warmup > 0: + print('Warming up') sampler.warmup( rng_key=rng_key, init_params=init, ) + print('Sampling...') run_ensemble( sampler, ESS(**ess_kwargs), @@ -1266,10 +1317,12 @@ def ess( n_parallel, ) elif warmup > 0: + print('Warming up') sampler.warmup( rng_key=rng_key, init_params=init, ) + print('Sampling...') run_ensemble( sampler, ESS(**ess_kwargs), @@ -1282,6 +1335,7 @@ def ess( n_parallel, ) else: + print('Sampling...') run_ensemble( sampler, ESS(**ess_kwargs), @@ -1333,10 +1387,10 @@ def sa( progress : bool, optional Whether to show progress bar during sampling. The default is True. filepath : str, optional - Path to the file where last_state are saved. Must have `.pkl` extension. + Path to the file where last state are saved. Must have `.pkl` extension. If None, no file are written. Default is None. resume : bool, optional - If True, read the last_state file from a previous run, and then, + If True, read the last state file from a previous run, and then, sampling will skip the warmup adaptation phase. Default is True. **sa_kwargs : dict Extra parameters passed to :class:`numpyro.infer.SA`. @@ -1383,20 +1437,25 @@ def sa( try: with open(filepath, 'rb') as f: last_state = dill.load(f) + print('Load last state file...') sampler.post_warmup_state = last_state + print('Sampling...') sampler.run(sampler.post_warmup_state.rng_key) except: - print('No last_state file found. Sampling...') + print('Failed to load last state file.') sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) elif warmup > 0: + print('Warming up') sampler.warmup( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) + print('Sampling...') sampler.run(sampler.post_warmup_state.rng_key) else: + print('Sampling...') sampler.run( rng_key=jax.random.PRNGKey(self._helper.seed['mcmc']), ) @@ -1421,7 +1480,6 @@ def run_ensemble( n_parallel=None, ): if chain_method == 'parallel': - print('Parallel sampling...') paral_mcmc = MCMC( kernel, num_warmup=warmup, From 4d7a635c845f853c563c04f814fb4262bf8cee4f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Dec 2024 13:41:15 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/elisa/infer/fit.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 6e52815..0dced07 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -647,8 +647,10 @@ def nuts( print('Load last state file...') sampler.post_warmup_state = last_state print('Sampling...') - sampler.run(rng_key=sampler.post_warmup_state.rng_key, - extra_fields=('energy', 'num_steps'),) + sampler.run( + rng_key=sampler.post_warmup_state.rng_key, + extra_fields=('energy', 'num_steps'), + ) except: print('Failed to load last state file.') sampler.run( @@ -663,8 +665,10 @@ def nuts( extra_fields=('energy', 'num_steps'), ) print('Sampling...') - sampler.run(rng_key=sampler.post_warmup_state.rng_key, - extra_fields=('energy', 'num_steps'),) + sampler.run( + rng_key=sampler.post_warmup_state.rng_key, + extra_fields=('energy', 'num_steps'), + ) else: print('Sampling...') sampler.run(