Impl consistency model posterior estimation
Showing 8 changed files with 623 additions and 2 deletions.
@@ -0,0 +1,85 @@
"""
Example using consistency model posterior estimation on a bivariate Gaussian
"""

import distrax
import haiku as hk
import matplotlib.pyplot as plt
import optax
import seaborn as sns
from jax import numpy as jnp
from jax import random as jr

from sbijax import SCMPE
from sbijax.nn import ConsistencyModel


def prior_model_fns():
    p = distrax.Independent(distrax.Normal(jnp.zeros(2), jnp.ones(2)), 1)
    return p.sample, p.log_prob


def simulator_fn(seed, theta):
    p = distrax.Normal(jnp.zeros_like(theta), 1.0)
    y = theta + p.sample(seed=seed)
    return y


def make_model(dim):
    @hk.transform
    def _mlp(method, **kwargs):
        # boundary parameterization: the skip coefficient is 1 and the output
        # coefficient is 0 at the smallest time 0.001, so the model reduces to
        # the identity there
        def _c_skip(time):
            return 1 / ((time - 0.001) ** 2 + 1)

        def _c_out(time):
            return 1.0 * (time - 0.001) / jnp.sqrt(1 + time**2)

        def _nn(theta, time, context, **kwargs):
            ins = jnp.concatenate([theta, time, context], axis=-1)
            outs = hk.nets.MLP([64, 64, dim])(ins)
            out_skip = _c_skip(time) * theta + _c_out(time) * outs
            return out_skip

        cm = ConsistencyModel(dim, _nn)
        return cm(method, **kwargs)

    return _mlp


def run():
    y_observed = jnp.array([2.0, -2.0])

    prior_simulator_fn, prior_logdensity_fn = prior_model_fns()
    fns = (prior_simulator_fn, prior_logdensity_fn), simulator_fn

    estim = SCMPE(fns, make_model(2))
    optimizer = optax.adam(1e-3)

    data, params = None, {}
    # two rounds: simulate (and append) training data given the current
    # parameters, then refit the consistency model
    for i in range(2):
        data, _ = estim.simulate_data_and_possibly_append(
            jr.fold_in(jr.PRNGKey(1), i),
            params=params,
            observable=y_observed,
            data=data,
        )
        params, info = estim.fit(
            jr.fold_in(jr.PRNGKey(2), i),
            data=data,
            optimizer=optimizer,
        )

    # sample from the approximate posterior given the observation and
    # plot the marginals
    rng_key = jr.PRNGKey(23)
    post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
    print(post_samples)
    fig, axes = plt.subplots(2)
    for i, ax in enumerate(axes):
        sns.histplot(post_samples[:, i], color="darkblue", ax=ax)
        ax.set_xlim([-3.0, 3.0])
    sns.despine()
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    run()
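As an aside, the hand-written make_model above wires the skip and output coefficients manually; the commit also adds a make_consistency_model factory (second file below) that builds the same kind of network from a residual backbone. A minimal sketch of using it instead, assuming the factory is re-exported from sbijax.nn next to ConsistencyModel:

# hypothetical replacement for make_model(2); assumes make_consistency_model
# is exposed under sbijax.nn like ConsistencyModel
from sbijax.nn import make_consistency_model

estim = SCMPE(fns, make_consistency_model(n_dimension=2, n_layers=2, hidden_size=64))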
@@ -0,0 +1,229 @@
from typing import Callable

import distrax
import haiku as hk
import jax
from jax import numpy as jnp
from jax.nn import glu
from scipy import integrate

__all__ = ["ConsistencyModel", "make_consistency_model"]

from sbijax._src.nn.make_resnet import _Resnet


class ConsistencyModel(hk.Module):
    """Conditional consistency model.

    Args:
        n_dimension: the dimensionality of the modelled space
        transform: a haiku module. The transform is a callable that has to
            take as input arguments named 'theta', 'time', 'context' and
            **kwargs. Theta, time and context are two-dimensional arrays
            with the same batch dimensions.
    """

    def __init__(self, n_dimension: int, transform: Callable, t_max=50):
        """Conditional consistency model.

        Args:
            n_dimension: the dimensionality of the modelled space
            transform: a haiku module. The transform is a callable that has to
                take as input arguments named 'theta', 'time', 'context' and
                **kwargs. Theta, time and context are two-dimensional arrays
                with the same batch dimensions.
            t_max: maximal time (noise scale) of the process
        """
        super().__init__()
        self._n_dimension = n_dimension
        self._network = transform
        self._t_max = t_max
        self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)

    def __call__(self, method, **kwargs):
        """Apply the model.

        Args:
            method (str): method to call

        Keyword Args:
            keyword arguments for the called method
        """
        return getattr(self, method)(**kwargs)

    def sample(self, context):
        """Sample from the pushforward.

        Args:
            context: array of conditioning variables
        """
        theta_0 = self._base_distribution.sample(
            seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
        )
        y_hat = self.vector_field(theta_0, self._t_max, context)
        return y_hat

    def vector_field(self, theta, time, context, **kwargs):
        """Compute the vector field.

        Args:
            theta: array of parameters
            time: time variables
            context: array of conditioning variables

        Keyword Args:
            keyword arguments that are passed to the neural network
        """
        time = jnp.full((theta.shape[0], 1), time)
        return self._network(theta=theta, time=time, context=context, **kwargs)


# pylint: disable=too-many-arguments
class _ResnetBlock(hk.Module):
    """A block for a 1d residual network."""

    def __init__(
        self,
        hidden_size: int,
        activation: Callable = jax.nn.relu,
        dropout_rate: float = 0.2,
        do_batch_norm: bool = False,
        batch_norm_decay: float = 0.1,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.activation = activation
        self.do_batch_norm = do_batch_norm
        self.dropout_rate = dropout_rate
        self.batch_norm_decay = batch_norm_decay

    def __call__(self, inputs, context, is_training=False):
        outputs = inputs
        if self.do_batch_norm:
            outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(
                outputs, is_training=is_training
            )
        outputs = hk.Linear(self.hidden_size)(outputs)
        outputs = self.activation(outputs)
        if is_training:
            outputs = hk.dropout(
                rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs
            )
        outputs = hk.Linear(self.hidden_size)(outputs)
        # the projected context (a joint embedding of theta and time) gates
        # the block's features through a GLU before the residual connection
        context_proj = hk.Linear(inputs.shape[-1])(context)
        outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1))
        return outputs + inputs


# pylint: disable=too-many-arguments
class _CMResnet(hk.Module):
    """A simplified 1-d residual network."""

    def __init__(
        self,
        n_layers: int,
        n_dimension: int,
        hidden_size: int,
        activation: Callable = jax.nn.relu,
        dropout_rate: float = 0.0,
        do_batch_norm: bool = False,
        batch_norm_decay: float = 0.1,
        eps: float = 0.001,
        sigma_data: float = 1.0,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.n_dimension = n_dimension
        self.hidden_size = hidden_size
        self.activation = activation
        self.do_batch_norm = do_batch_norm
        self.dropout_rate = dropout_rate
        self.batch_norm_decay = batch_norm_decay
        self.sigma_data = sigma_data
        self.var_data = self.sigma_data**2
        self.eps = eps

    def __call__(self, theta, time, context, is_training=False, **kwargs):
        outputs = context
        t_theta_embedding = jnp.concatenate(
            [
                hk.Linear(self.n_dimension)(theta),
                hk.Linear(self.n_dimension)(time),
            ],
            axis=-1,
        )
        outputs = hk.Linear(self.hidden_size)(outputs)
        outputs = self.activation(outputs)
        for _ in range(self.n_layers):
            outputs = _ResnetBlock(
                hidden_size=self.hidden_size,
                activation=self.activation,
                dropout_rate=self.dropout_rate,
                do_batch_norm=self.do_batch_norm,
                batch_norm_decay=self.batch_norm_decay,
            )(outputs, context=t_theta_embedding, is_training=is_training)
        outputs = self.activation(outputs)
        outputs = hk.Linear(self.n_dimension)(outputs)

        # TODO(simon): how is sigma_data chosen automatically?
        # in the meantime set it to 1 and use batch norm before
        # outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(outputs, is_training=is_training)
        out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs
        return out_skip

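    # The coefficient schedules below implement the boundary parameterization
    # of consistency models,
    #   f(theta, t) = c_skip(t) * theta + c_out(t) * F(theta, t),
    # with c_skip(eps) = 1 and c_out(eps) = 0, so the parameterized function
    # reduces to the identity at the smallest time t = eps.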
    def _c_skip(self, time):
        return self.var_data / ((time - self.eps) ** 2 + self.var_data)

    def _c_out(self, time):
        return (
            self.sigma_data
            * (time - self.eps)
            / jnp.sqrt(self.var_data + time**2)
        )
|
||
|
||
def make_consistency_model( | ||
n_dimension: int, | ||
n_layers: int = 2, | ||
hidden_size: int = 64, | ||
activation: Callable = jax.nn.tanh, | ||
dropout_rate: float = 0.2, | ||
do_batch_norm: bool = False, | ||
batch_norm_decay: float = 0.2, | ||
t_max: float=50, | ||
epsilon=0.001, | ||
sigma_data:float=1.0 | ||
): | ||
"""Create a conditional continuous normalizing flow. | ||
The CCNF uses a residual network as transformer which is created | ||
automatically. | ||
Args: | ||
n_dimension: dimensionality of modelled space | ||
n_layers: number of resnet blocks | ||
hidden_size: sizes of hidden layers for each resnet block | ||
activation: a jax activation function | ||
dropout_rate: dropout rate to use in resnet blocks | ||
do_batch_norm: use batch normalization or not | ||
batch_norm_decay: decay rate of EMA in batch norm layer | ||
Returns: | ||
returns a conditional continuous normalizing flow | ||
""" | ||
|
||
@hk.transform | ||
def _cm(method, **kwargs): | ||
nn = _CMResnet( | ||
n_layers=n_layers, | ||
n_dimension=n_dimension, | ||
hidden_size=hidden_size, | ||
activation=activation, | ||
do_batch_norm=do_batch_norm, | ||
dropout_rate=dropout_rate, | ||
batch_norm_decay=batch_norm_decay, | ||
eps=epsilon, | ||
sigma_data=sigma_data, | ||
) | ||
cm = ConsistencyModel(n_dimension, nn, t_max=t_max) | ||
return cm(method, **kwargs) | ||
|
||
return _cm |
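For reference, a minimal sketch of exercising the transformed factory above on its own, outside of SCMPE. This is not part of the commit; it assumes the standard Haiku init/apply conventions and the two-dimensional setup of the example file:

from jax import numpy as jnp, random as jr

# hypothetical standalone check of make_consistency_model (not in the commit)
cm = make_consistency_model(n_dimension=2)
context = jnp.zeros((10, 2))  # ten conditioning observations
# initialize parameters and draw one batch of samples via the 'sample' method
params = cm.init(jr.PRNGKey(0), method="sample", context=context)
theta_hat = cm.apply(params, jr.PRNGKey(1), method="sample", context=context)
print(theta_hat.shape)  # expected: (10, 2)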