Skip to content

Commit

Permalink
Add kwargs argument to sample posterior
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier committed Apr 3, 2023
1 parent 605d5d7 commit 90171f8
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 5 deletions.
4 changes: 3 additions & 1 deletion examples/bivariate_gaussian_snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def run():
)

nuts_samples = sample_with_nuts(rng_seq, log_density, 2, 4, 2000, 1000)
snl_samples, _ = snl.sample_posterior(params, 4, 10000, 7500)
snl_samples, _ = snl.sample_posterior(
params, 4, 10000, 7500, sampler="slice"
)

snl_samples = snl_samples.reshape(-1, 2)
nuts_samples = nuts_samples.reshape(-1, 2)
Expand Down
4 changes: 3 additions & 1 deletion examples/slcp_snl_masked_autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def run(use_surjectors):
random.PRNGKey(23), y_observed, optimizer, n_rounds=3, sampler="slice"
)

snl_samples, _ = snl.sample_posterior(params, 4, 20000, 10000)
snl_samples, _ = snl.sample_posterior(
params, 4, 20000, 10000, sampler="slice"
)
snl_samples = snl_samples.reshape(-1, len_theta)

def log_density_fn(theta, y):
Expand Down
2 changes: 1 addition & 1 deletion sbijax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
sbijax: Simulation-based inference in JAX
"""

__version__ = "0.0.8"
__version__ = "0.0.9"


from sbijax.abc.rejection_abc import RejectionABC
Expand Down
11 changes: 9 additions & 2 deletions sbijax/snl.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def fit(
before stopping optimisation
kwargs: keyword arguments with sampler specific parameters. For slice
sampling the following arguments are possible:
- sampler: either 'nuts', 'slice' or None (defaults to nuts)
- n_thin: number of thinning steps
- n_doubling: number of doubling steps of the interval
- step_size: step size of the initial interval
Expand Down Expand Up @@ -139,7 +140,7 @@ def fit(
return params, snl_info(all_params, all_losses, all_diagnostics)

# pylint: disable=arguments-differ
def sample_posterior(self, params, n_chains, n_samples, n_warmup):
def sample_posterior(self, params, n_chains, n_samples, n_warmup, **kwargs):
"""
Sample from the approximate posterior
Expand All @@ -153,6 +154,12 @@ def sample_posterior(self, params, n_chains, n_samples, n_warmup):
number of samples per chain
n_warmup: int
number of samples to discard
kwargs: keyword arguments with sampler specific parameters. For slice
sampling the following arguments are possible:
- sampler: either 'nuts', 'slice' or None (defaults to nuts)
- n_thin: number of thinning steps
- n_doubling: number of doubling steps of the interval
- step_size: step size of the initial interval
Returns
-------
Expand All @@ -162,7 +169,7 @@ def sample_posterior(self, params, n_chains, n_samples, n_warmup):
"""

return self._simulate_from_amortized_posterior(
params, n_chains, n_samples, n_warmup
params, n_chains, n_samples, n_warmup, **kwargs
)

def _fit_model_single_round(
Expand Down

0 comments on commit 90171f8

Please sign in to comment.