Skip to content

Commit

Permalink
Add minor functionality (#29)
Browse files Browse the repository at this point in the history
* Update makefile
* Add new functions for sampling
  • Loading branch information
dirmeier authored Apr 22, 2024
1 parent 5851493 commit 3f7edcb
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 13 deletions.
14 changes: 14 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
.PHONY: tag
.PHONY: tests
.PHONY: lints
.PHONY: docs

PKG_VERSION=`hatch version`

tag:
git tag -a v${PKG_VERSION} -m v${PKG_VERSION}
git push --tag

tests:
hatch run test:test

lints:
hatch run test:lint

docs:
cd docs && make html
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
dynamic = ["version"]

[project.urls]
homepage = "https://github.com/dirmeier/sbijax"
Homepage = "https://github.com/dirmeier/sbijax"

[tool.hatch.metadata]
allow-direct-references = true
Expand Down
2 changes: 1 addition & 1 deletion sbijax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""sbijax: Simulation-based inference in JAX."""

__version__ = "0.2.0"
__version__ = "0.2.0.post0"

from sbijax._src.abc.smc_abc import SMCABC
from sbijax._src.scmpe import SCMPE
Expand Down
59 changes: 48 additions & 11 deletions sbijax/_src/_sne_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def simulate_data_and_possibly_append(
n_simulations=1000,
**kwargs,
):
"""Simulate data from the prior or posterior and append.
"""Simulate data and paarameters from the prior or posterior and append.
Args:
rng_key: a random key
Expand Down Expand Up @@ -76,7 +76,7 @@ def sample_posterior(self, rng_key, params, observable, *args, **kwargs):
**kwargs: keyword arguments
"""

def simulate_data(
def simulate_parameters(
self,
rng_key,
*,
Expand All @@ -85,27 +85,26 @@ def simulate_data(
n_simulations=1000,
**kwargs,
):
r"""Simulate data from the posterior or prior and append.
r"""Simulate parameters from the posterior or prior.
Args:
rng_key: a random key
params:a dictionary of neural network parameters. If None, will
draw from prior. If parameters given, will draw from amortized
posterior using 'observable;
observable: an observation. Needs to be gfiven if posterior draws
posterior using 'observable'.
observable: an observation. Needs to be given if posterior draws
are desired
n_simulations: number of newly simulated data
kwargs: dictionary of ey value pairs passed to `sample_posterior`
Returns:
a NamedTuple of two axis, y and theta
"""
sample_key, rng_key = jr.split(rng_key)
if params is None or len(params) == 0:
diagnostics = None
self.n_total_simulations += n_simulations
new_thetas = self.prior_sampler_fn(
seed=sample_key,
seed=rng_key,
sample_shape=(n_simulations,),
)
else:
Expand All @@ -117,7 +116,7 @@ def simulate_data(
if "n_samples" not in kwargs:
kwargs["n_samples"] = n_simulations
new_thetas, diagnostics = self.sample_posterior(
rng_key=sample_key,
rng_key=rng_key,
params=params,
observable=jnp.atleast_2d(observable),
**kwargs,
Expand All @@ -126,15 +125,53 @@ def simulate_data(
new_thetas = jr.permutation(perm_key, new_thetas)
new_thetas = new_thetas[:n_simulations, :]

simulate_key, rng_key = jr.split(rng_key)
new_obs = self.simulator_fn(seed=simulate_key, theta=new_thetas)
return new_thetas, diagnostics

def simulate_data(
self,
rng_key,
*,
params=None,
observable=None,
n_simulations=1000,
**kwargs,
):
r"""Simulate data from the posterior or prior and append.
Args:
rng_key: a random key
params:a dictionary of neural network parameters. If None, will
draw from prior. If parameters given, will draw from amortized
posterior using 'observable;
observable: an observation. Needs to be gfiven if posterior draws
are desired
n_simulations: number of newly simulated data
kwargs: dictionary of ey value pairs passed to `sample_posterior`
Returns:
a NamedTuple of two axis, y and theta
"""
theta_key, data_key = jr.split(rng_key)

new_thetas, diagnostics = self.simulate_parameters(
theta_key,
params=params,
observable=observable,
n_simulations=n_simulations,
**kwargs,
)

new_obs = self.simulate_observations(data_key, new_thetas)
chex.assert_shape(new_thetas, [n_simulations, None])
chex.assert_shape(new_obs, [n_simulations, None])

new_data = named_dataset(new_obs, new_thetas)

return new_data, diagnostics

def simulate_observations(self, rng_key, thetas):
new_obs = self.simulator_fn(seed=rng_key, theta=thetas)
return new_obs

@staticmethod
def as_iterators(
rng_key, data, batch_size, percentage_data_as_validation_set
Expand Down

0 comments on commit 3f7edcb

Please sign in to comment.