Skip to content

Commit

Permalink
minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
wcxve committed Mar 1, 2024
1 parent 0d48c54 commit f35902e
Showing 1 changed file with 77 additions and 48 deletions.
125 changes: 77 additions & 48 deletions src/elisa/infer/nested_sampling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Copy and adapt from
"""Copy and adapt from
https://github.com/pyro-ppl/numpyro/raw/master/numpyro/contrib/nested_sampling.py
"""
# Copyright Contributors to the Pyro project.
Expand All @@ -14,14 +13,17 @@
import numpyro
import numpyro.distributions as dist
import tensorflow_probability.substrates.jax as tfp

from jax import random
from numpyro.handlers import reparam, seed, trace
from numpyro.infer import Predictive
from numpyro.infer.reparam import Reparam
from numpyro.infer.util import _guess_max_plate_nesting, _validate_model, log_density
from numpyro.infer.util import (
_guess_max_plate_nesting,
_validate_model,
log_density,
)

__all__ = ["NestedSampler"]
__all__ = ['NestedSampler']

tfpd = tfp.distributions

Expand All @@ -34,10 +36,13 @@ def uniform_reparam_transform(d):
"""
if isinstance(d, dist.TransformedDistribution):
outer_transform = dist.transforms.ComposeTransform(d.transforms)
return lambda q: outer_transform(uniform_reparam_transform(d.base_dist)(q))
return lambda q: outer_transform(
uniform_reparam_transform(d.base_dist)(q)
)

if isinstance(
d, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)
d,
(dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution),
):
return lambda q: uniform_reparam_transform(d.base_dist)(q)

Expand All @@ -63,7 +68,9 @@ def transform(q):
@uniform_reparam_transform.register(dist.CategoricalLogits)
@uniform_reparam_transform.register(dist.CategoricalProbs)
def _(d):
return lambda q: jnp.sum(jnp.cumsum(d.probs, axis=-1) < q[..., None], axis=-1)
return lambda q: jnp.sum(
jnp.cumsum(d.probs, axis=-1) < q[..., None], axis=-1
)


@uniform_reparam_transform.register(dist.Dirichlet)
Expand All @@ -90,15 +97,20 @@ class UniformReparam(Reparam):
"""

def __call__(self, name, fn, obs):
assert obs is None, "TransformReparam does not support observe statements"
assert (
obs is None
), 'TransformReparam does not support observe statements'
shape = fn.shape()
fn, expand_shape, event_dim = self._unwrap(fn)
transform = uniform_reparam_transform(fn)
tiny = jnp.finfo(jnp.result_type(float)).tiny

x = numpyro.sample(
"{}_base".format(name),
dist.Uniform(tiny, 1).expand(shape).to_event(event_dim).mask(False),
f'{name}_base',
dist.Uniform(tiny, 1)
.expand(shape)
.to_event(event_dim)
.mask(False),
)
# Simulate a numpyro.deterministic() site.
return None, transform(x)
Expand Down Expand Up @@ -136,17 +148,17 @@ class NestedSampler:
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> import numpyro.distributions as prior
>>> from numpyro.contrib.nested_sampling import NestedSampler
>>> true_coefs = jnp.array([1., 2., 3.])
>>> data = random.normal(random.PRNGKey(0), (2000, 3))
>>> labels = dist.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>> labels = prior.Bernoulli(logits=(true_coefs * data).sum(-1)).sample(random.PRNGKey(1))
>>>
>>> def model(data, labels):
... coefs = numpyro.sample('coefs', dist.Normal(0, 1).expand([3]))
... intercept = numpyro.sample('intercept', dist.Normal(0., 10.))
... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
... coefs = numpyro.sample('coefs', prior.Normal(0, 1).expand([3]))
... intercept = numpyro.sample('intercept', prior.Normal(0., 10.))
... return numpyro.sample('y', prior.Bernoulli(logits=(coefs * data + intercept).sum(-1)),
... obs=labels)
>>>
>>> ns = NestedSampler(model)
Expand Down Expand Up @@ -178,6 +190,7 @@ def __init__(
# jaxns is import here because it runs jax program when importing
logging.disable(logging.INFO) # temporarily disable jaxns logging
import jaxns

logging.disable(logging.NOTSET)
self._jaxns = jaxns

Expand All @@ -193,31 +206,36 @@ def run(self, rng_key, *args, **kwargs):

rng_sampling, rng_predictive = random.split(rng_key)
# reparam the model so that latent sites have Uniform(0, 1) priors
prototype_trace = trace(seed(self.model, rng_key)).get_trace(*args, **kwargs)
prototype_trace = trace(seed(self.model, rng_key)).get_trace(
*args, **kwargs
)
param_names = [
site["name"]
site['name']
for site in prototype_trace.values()
if site["type"] == "sample"
and not site["is_observed"]
and site["infer"].get("enumerate", "") != "parallel"
if site['type'] == 'sample'
and not site['is_observed']
and site['infer'].get('enumerate', '') != 'parallel'
]
deterministics = [
site["name"]
site['name']
for site in prototype_trace.values()
if site["type"] == "deterministic"
if site['type'] == 'deterministic'
]
reparam_model = reparam(
self.model, config={k: UniformReparam() for k in param_names}
)

# enable enumerate if needed
has_enum = any(
site["type"] == "sample"
and site["infer"].get("enumerate", "") == "parallel"
site['type'] == 'sample'
and site['infer'].get('enumerate', '') == 'parallel'
for site in prototype_trace.values()
)
if has_enum:
from numpyro.contrib.funsor import enum, log_density as log_density_
from numpyro.contrib.funsor import (
enum,
log_density as log_density_,
)

max_plate_nesting = _guess_max_plate_nesting(prototype_trace)
_validate_model(prototype_trace)
Expand All @@ -231,20 +249,20 @@ def run(self, rng_key, *args, **kwargs):
\tparams = dict({})\n
\treturn log_density_(reparam_model, args, kwargs, params)[0]
""".format(
", ".join([f"{name}_base" for name in param_names]),
", ".join([f"{name}_base={name}_base" for name in param_names]),
', '.join([f'{name}_base' for name in param_names]),
', '.join([f'{name}_base={name}_base' for name in param_names]),
)
exec(loglik_fn_def, locals(), local_dict)
loglik_fn = local_dict["loglik_fn"]
loglik_fn = local_dict['loglik_fn']

# use NestedSampler with identity prior chain
def prior_model():
params = []
for name in param_names:
shape = prototype_trace[name]["fn"].shape()
shape = prototype_trace[name]['fn'].shape()
param = yield jaxns.Prior(
tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)),
name=name + "_base",
name=name + '_base',
)
params.append(param)
return tuple(params)
Expand Down Expand Up @@ -281,7 +299,9 @@ def prior_model():
rng_sampling,
term_cond=jaxns.TerminationCondition(**self.termination_kwargs),
)
results = ns.to_results(termination_reason=termination_reason, state=state)
results = ns.to_results(
termination_reason=termination_reason, state=state
)

# transform base samples back to original domains
# Here we only transform the first valid num_samples samples
Expand All @@ -308,7 +328,7 @@ def get_samples(self, rng_key, num_samples=None):

if self._results is None:
raise RuntimeError(
"NestedSampler.run(...) method should be called first to obtain results."
'NestedSampler.run(...) method should be called first to obtain results.'
)

weighted_samples, sample_weights = self.get_weighted_samples()
Expand All @@ -319,7 +339,11 @@ def get_samples(self, rng_key, num_samples=None):
num_samples = int(num_samples)

return jaxns.resample(
rng_key, weighted_samples, sample_weights, S=num_samples, replace=True
rng_key,
weighted_samples,
sample_weights,
S=num_samples,
replace=True,
)

def get_weighted_samples(self):
Expand All @@ -328,7 +352,7 @@ def get_weighted_samples(self):
"""
if self._results is None:
raise RuntimeError(
"NestedSampler.run(...) method should be called first to obtain results."
'NestedSampler.run(...) method should be called first to obtain results.'
)

return self._results.samples, self._results.log_dp_mean
Expand All @@ -341,7 +365,7 @@ def print_summary(self):

if self._results is None:
raise RuntimeError(
"NestedSampler.run(...) method should be called first to obtain results."
'NestedSampler.run(...) method should be called first to obtain results.'
)
jaxns.summary(self._results)

Expand All @@ -354,7 +378,7 @@ def diagnostics(self):

if self._results is None:
raise RuntimeError(
"NestedSampler.run(...) method should be called first to obtain results."
'NestedSampler.run(...) method should be called first to obtain results.'
)
jaxns.plot_diagnostics(self._results)
jaxns.plot_cornerplot(self._results)
Expand All @@ -365,25 +389,25 @@ def reparam_loglike(model, rng_key, *args, **kwargs):
# reparam the model so that latent sites have Uniform(0, 1) priors
prototype_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs)
param_names = [
site["name"]
site['name']
for site in prototype_trace.values()
if site["type"] == "sample"
and not site["is_observed"]
and site["infer"].get("enumerate", "") != "parallel"
if site['type'] == 'sample'
and not site['is_observed']
and site['infer'].get('enumerate', '') != 'parallel'
]
deterministics = [
site["name"]
site['name']
for site in prototype_trace.values()
if site["type"] == "deterministic"
if site['type'] == 'deterministic'
]
reparam_model = reparam(
model, config={k: UniformReparam() for k in param_names}
)

# enable enumerate if needed
has_enum = any(
site["type"] == "sample"
and site["infer"].get("enumerate", "") == "parallel"
site['type'] == 'sample'
and site['infer'].get('enumerate', '') == 'parallel'
for site in prototype_trace.values()
)
if has_enum:
Expand All @@ -401,8 +425,8 @@ def reparam_loglike(model, rng_key, *args, **kwargs):
\tparams = dict({})\n
\treturn log_density_(reparam_model, args, kwargs, params)[0]
""".format(
", ".join([f"{name}_base" for name in param_names]),
", ".join([f"{name}_base={name}_base" for name in param_names]),
', '.join([f'{name}_base' for name in param_names]),
', '.join([f'{name}_base={name}_base' for name in param_names]),
)
exec(loglik_fn_def, locals(), local_dict)

Expand All @@ -411,4 +435,9 @@ def transform_back(samples):
reparam_model, samples, return_sites=param_names + deterministics
)
return predictive(rng_predictive, *args, **kwargs)
return local_dict["loglik_fn"], transform_back, [f"{name}_base" for name in param_names]

return (
local_dict['loglik_fn'],
transform_back,
[f'{name}_base' for name in param_names],
)

0 comments on commit f35902e

Please sign in to comment.