fix some things and docu
dirmeier committed Feb 29, 2024
1 parent d7e16ef commit 218a4b4
Showing 7 changed files with 127 additions and 188 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -20,6 +20,7 @@
- [Neural Approximate Sufficient Statistics](https://arxiv.org/abs/2010.10079) (`SNASS`)
- [Neural Approximate Slice Sufficient Statistics](https://openreview.net/forum?id=jjzJ768iV1) (`SNASSS`)
- [Flow matching posterior estimation](https://arxiv.org/abs/2305.17161) (`SFMPE`)
- [Consistency model posterior estimation](https://arxiv.org/abs/2312.05440) (`SCMPE`)

where the acronyms in parentheses denote the names of the methods in `sbijax`.
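
As a rough, hedged illustration of how the new `SCMPE` estimator might be driven end to end (the import path, the constructor arguments, and the `prior_fn`/`simulator_fn`/`y_observed` placeholders below are assumptions pieced together from the docs entry and the bivariate-Gaussian example later in this diff, not something the commit itself ships):

import jax.numpy as jnp
import jax.random as jr
from sbijax import SCMPE  # assumed public export

# placeholders for the user's prior, simulator, and observation
prior_fn, simulator_fn = ..., ...
y_observed = jnp.array([[2.0, -2.0]])

# the constructor signature and the tuple-shaped return values below are
# assumptions; only sample_posterior(rng_key, params, y_observed) is taken
# verbatim from the example script
estim = SCMPE(prior_fn, simulator_fn)
data, _ = estim.simulate_data_and_possibly_append(jr.PRNGKey(0))
params, losses = estim.fit(jr.PRNGKey(1), data=data)
post_samples, _ = estim.sample_posterior(jr.PRNGKey(2), params, y_observed)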

1 change: 1 addition & 0 deletions docs/index.rst
@@ -21,6 +21,7 @@
- `Neural Approximate Sufficient Statistics <https://arxiv.org/abs/2010.10079>`_ (:code:`SNASS`)
- `Neural Approximate Slice Sufficient Statistics <https://openreview.net/forum?id=jjzJ768iV1>`_ (:code:`SNASSS`)
- `Flow matching posterior estimation <https://arxiv.org/abs/2305.17161>`_ (:code:`SFMPE`)
- `Consistency model posterior estimation <https://arxiv.org/abs/2312.05440>`_ (:code:`SCMPE`)

.. caution::

2 changes: 1 addition & 1 deletion docs/sbijax.rst
@@ -53,7 +53,7 @@ SFMPE
SCMPE
~~~~~

.. autoclass:: SFMPE
.. autoclass:: SCMPE
:members: fit, simulate_data_and_possibly_append, sample_posterior

SNASS
4 changes: 2 additions & 2 deletions examples/bivariate_gaussian_cfmpe.py
@@ -32,7 +32,8 @@ def _c_skip(time):
return 1 / ((time - 0.001) ** 2 + 1)

def _c_out(time):
return 1.0 * (time - 0.001) / jnp.sqrt(1 + time ** 2)
return 1.0 * (time - 0.001) / jnp.sqrt(1 + time**2)

def _nn(theta, time, context, **kwargs):
ins = jnp.concatenate([theta, time, context], axis=-1)
outs = hk.nets.MLP([64, 64, dim])(ins)
@@ -68,7 +69,6 @@ def run():
optimizer=optimizer,
)


rng_key = jr.PRNGKey(23)
post_samples, _ = estim.sample_posterior(rng_key, params, y_observed)
print(post_samples)
115 changes: 49 additions & 66 deletions sbijax/_src/nn/consistency_model.py
@@ -4,39 +4,48 @@
import haiku as hk
import jax
from jax import numpy as jnp
from jax.nn import glu
from scipy import integrate

__all__ = ["ConsistencyModel", "make_consistency_model"]
from sbijax._src.nn.continuous_normalizing_flow import _ResnetBlock

from sbijax._src.nn.make_resnet import _Resnet
__all__ = ["ConsistencyModel", "make_consistency_model"]


class ConsistencyModel(hk.Module):
"""Conditional continuous normalizing flow.
"""A consistency model.
Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
"""

def __init__(self, n_dimension: int, transform: Callable, t_max=50):
"""Conditional continuous normalizing flow.
def __init__(
self,
n_dimension: int,
transform: Callable,
t_min: float = 0.001,
t_max: float = 50.0,
):
"""Construct a consistency model.
Args:
n_dimension: the dimensionality of the modelled space
transform: a haiku module. The transform is a callable that has to
take as input arguments named 'theta', 'time', 'context' and
**kwargs. Theta, time and context are two-dimensional arrays
with the same batch dimensions.
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
"""
super().__init__()
self._n_dimension = n_dimension
self._network = transform
self._t_max = t_max
self._t_min = t_min
self._base_distribution = distrax.Normal(jnp.zeros(n_dimension), 1.0)

def __call__(self, method, **kwargs):
@@ -50,16 +59,26 @@ def __call__(self, method, **kwargs):
"""
return getattr(self, method)(**kwargs)

def sample(self, context):
"""Sample from the pushforward.
def sample(self, context, **kwargs):
"""Sample from the consistency model.
Args:
context: array of conditioning variables
kwargs: keyword arguments like 'is_training'
"""
theta_0 = self._base_distribution.sample(
noise = self._base_distribution.sample(
seed=hk.next_rng_key(), sample_shape=(context.shape[0],)
)
y_hat = self.vector_field(theta_0, self._t_max, context)
y_hat = self.vector_field(noise, self._t_max, context, **kwargs)

noise = self._base_distribution.sample(
seed=hk.next_rng_key(), sample_shape=(y_hat.shape[0],)
)
tme = self._t_min + (self._t_max - self._t_min) / 2
noise = jnp.sqrt(jnp.square(tme) - jnp.square(self._t_min)) * noise
y_tme = y_hat + noise
y_hat = self.vector_field(y_tme, tme, context, **kwargs)

return y_hat

def vector_field(self, theta, time, context, **kwargs):
@@ -77,43 +96,6 @@ def vector_field(self, theta, time, context, **kwargs):
return self._network(theta=theta, time=time, context=context, **kwargs)


# pylint: disable=too-many-arguments
class _ResnetBlock(hk.Module):
"""A block for a 1d residual network."""

def __init__(
self,
hidden_size: int,
activation: Callable = jax.nn.relu,
dropout_rate: float = 0.2,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.1,
):
super().__init__()
self.hidden_size = hidden_size
self.activation = activation
self.do_batch_norm = do_batch_norm
self.dropout_rate = dropout_rate
self.batch_norm_decay = batch_norm_decay

def __call__(self, inputs, context, is_training=False):
outputs = inputs
if self.do_batch_norm:
outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(
outputs, is_training=is_training
)
outputs = hk.Linear(self.hidden_size)(outputs)
outputs = self.activation(outputs)
if is_training:
outputs = hk.dropout(
rng=hk.next_rng_key(), rate=self.dropout_rate, x=outputs
)
outputs = hk.Linear(self.hidden_size)(outputs)
context_proj = hk.Linear(inputs.shape[-1])(context)
outputs = glu(jnp.concatenate([outputs, context_proj], axis=-1))
return outputs + inputs


# pylint: disable=too-many-arguments
class _CMResnet(hk.Module):
"""A simplified 1-d residual network."""
@@ -127,8 +109,8 @@ def __init__(
dropout_rate: float = 0.0,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.1,
eps: float = 0.001,
sigma_data:float = 1.0
t_min: float = 0.001,
sigma_data: float = 1.0,
):
super().__init__()
self.n_layers = n_layers
@@ -140,9 +122,9 @@ def __init__(
self.batch_norm_decay = batch_norm_decay
self.sigma_data = sigma_data
self.var_data = self.sigma_data**2
self.eps = eps
self.t_min = t_min

def __call__(self, theta, time, context, is_training=False, **kwargs):
def __call__(self, theta, time, context, is_training, **kwargs):
outputs = context
t_theta_embedding = jnp.concatenate(
[
@@ -164,19 +146,17 @@ def __call__(self, theta, time, context, is_training=False, **kwargs):
outputs = self.activation(outputs)
outputs = hk.Linear(self.n_dimension)(outputs)

# TODO(simon): how is sigma_data chosen automatically?
# in the meantime set it to 1 and use batch norm before
#outputs = hk.BatchNorm(True, True, self.batch_norm_decay)(outputs, is_training=is_training)
# TODO(simon): can we choose sigma automatically?
out_skip = self._c_skip(time) * theta + self._c_out(time) * outputs
return out_skip

def _c_skip(self, time):
return self.var_data / ((time - self.eps) ** 2 + self.var_data)
return self.var_data / ((time - self.t_min) ** 2 + self.var_data)

def _c_out(self, time):
return (
self.sigma_data
* (time - self.eps)
* (time - self.t_min)
/ jnp.sqrt(self.var_data + time**2)
)
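# For reference (the standard consistency-model parametrization, not
# something introduced by this diff): the two weights above implement
#     f(theta, t) = c_skip(t) * theta + c_out(t) * F(theta, t, context)
# with c_skip(t_min) = 1 and c_out(t_min) = 0, so that f(theta, t_min) = theta,
# which is the boundary condition a consistency function must satisfy.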

@@ -189,14 +169,13 @@ def make_consistency_model(
dropout_rate: float = 0.2,
do_batch_norm: bool = False,
batch_norm_decay: float = 0.2,
t_max: float=50,
epsilon=0.001,
sigma_data:float=1.0
t_min: float = 0.001,
t_max: float = 50.0,
sigma_data: float = 1.0,
):
"""Create a conditional continuous normalizing flow.
"""Create a consistency model.
The CCNF uses a residual network as transformer which is created
automatically.
The consistency model uses a residual network to parameterize the consistency function.
Args:
n_dimension: dimensionality of modelled space
@@ -206,8 +185,12 @@
dropout_rate: dropout rate to use in resnet blocks
do_batch_norm: use batch normalization or not
batch_norm_decay: decay rate of EMA in batch norm layer
t_min: minimal time point for ODE integration
t_max: maximal time point for ODE integration
sigma_data: the standard deviation of the data :)
Returns:
returns a conditional continuous normalizing flow
returns a consistency model
"""

@hk.transform
@@ -220,10 +203,10 @@ def _cm(method, **kwargs):
do_batch_norm=do_batch_norm,
dropout_rate=dropout_rate,
batch_norm_decay=batch_norm_decay,
eps=epsilon,
t_min=t_min,
sigma_data=sigma_data,
)
cm = ConsistencyModel(n_dimension, nn, t_max=t_max)
cm = ConsistencyModel(n_dimension, nn, t_min=t_min, t_max=t_max)
return cm(method, **kwargs)

return _cm
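
A minimal sketch of constructing the network (all keyword arguments besides `n_dimension` have defaults; whether the returned transformed module is exactly what `SCMPE` expects is an assumption based on the example script, not something stated in this diff):

nn = make_consistency_model(n_dimension=2)
# `nn` is the hk.transform-ed `_cm` above; presumably handed to SCMPE as its
# neural network, analogous to how SFMPE consumes a flow-matching network.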