diff --git a/examples/bivariate_gaussian_snl.py b/examples/bivariate_gaussian_snl.py index ce924b9..707f6c0 100644 --- a/examples/bivariate_gaussian_snl.py +++ b/examples/bivariate_gaussian_snl.py @@ -92,7 +92,9 @@ def run(): ) nuts_samples = sample_with_nuts(rng_seq, log_density, 2, 4, 2000, 1000) - snl_samples, _ = snl.sample_posterior(params, 4, 10000, 7500) + snl_samples, _ = snl.sample_posterior( + params, 4, 10000, 7500, sampler="slice" + ) snl_samples = snl_samples.reshape(-1, 2) nuts_samples = nuts_samples.reshape(-1, 2) diff --git a/examples/slcp_snl_masked_autoregressive.py b/examples/slcp_snl_masked_autoregressive.py index ef91ffe..6e75fba 100644 --- a/examples/slcp_snl_masked_autoregressive.py +++ b/examples/slcp_snl_masked_autoregressive.py @@ -158,7 +158,9 @@ def run(use_surjectors): random.PRNGKey(23), y_observed, optimizer, n_rounds=3, sampler="slice" ) - snl_samples, _ = snl.sample_posterior(params, 4, 20000, 10000) + snl_samples, _ = snl.sample_posterior( + params, 4, 20000, 10000, sampler="slice" + ) snl_samples = snl_samples.reshape(-1, len_theta) def log_density_fn(theta, y): diff --git a/sbijax/__init__.py b/sbijax/__init__.py index 6e5ff0a..5a235d6 100644 --- a/sbijax/__init__.py +++ b/sbijax/__init__.py @@ -2,7 +2,7 @@ sbijax: Simulation-based inference in JAX """ -__version__ = "0.0.8" +__version__ = "0.0.9" from sbijax.abc.rejection_abc import RejectionABC diff --git a/sbijax/snl.py b/sbijax/snl.py index 5e45d18..91f6372 100644 --- a/sbijax/snl.py +++ b/sbijax/snl.py @@ -90,6 +90,7 @@ def fit( before stopping optimisation kwargs: keyword arguments with sampler specific parameters. For slice sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) - n_thin: number of thinning steps - n_doubling: number of doubling steps of the interval - step_size: step size of the initial interval @@ -139,7 +140,7 @@ def fit( return params, snl_info(all_params, all_losses, all_diagnostics) # pylint: disable=arguments-differ - def sample_posterior(self, params, n_chains, n_samples, n_warmup): + def sample_posterior(self, params, n_chains, n_samples, n_warmup, **kwargs): """ Sample from the approximate posterior @@ -153,6 +154,12 @@ def sample_posterior(self, params, n_chains, n_samples, n_warmup): number of samples per chain n_warmup: int number of samples to discard + kwargs: keyword arguments with sampler specific parameters. For slice + sampling the following arguments are possible: + - sampler: either 'nuts', 'slice' or None (defaults to nuts) + - n_thin: number of thinning steps + - n_doubling: number of doubling steps of the interval + - step_size: step size of the initial interval Returns ------- @@ -162,7 +169,7 @@ def sample_posterior(self, params, n_chains, n_samples, n_warmup): """ return self._simulate_from_amortized_posterior( - params, n_chains, n_samples, n_warmup + params, n_chains, n_samples, n_warmup, **kwargs ) def _fit_model_single_round(