-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d6bd23c
commit 66a3c56
Showing
4 changed files
with
351 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
Empty file.
250 changes: 250 additions & 0 deletions
250
src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,250 @@ | ||
from abc import abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
import genjax | ||
import jax | ||
import jax.numpy as jnp | ||
from genjax import Pytree | ||
from genjax.typing import FloatArray, PRNGKey, ScalarBool, ScalarFloat | ||
from tensorflow_probability.substrates import jax as tfp | ||
|
||
from b3d.chisight.dense.likelihoods.other_likelihoods import PythonMixturePixelModel | ||
from b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel import ( | ||
_FIXED_COLOR_UNIFORM_WINDOW, | ||
truncated_laplace, | ||
) | ||
|
||
if TYPE_CHECKING: | ||
import tensorflow_probability.python.distributions.distribution as dist | ||
|
||
COLOR_MIN_VAL: float = 0.0 | ||
COLOR_MAX_VAL: float = 1.0 | ||
|
||
|
||
def is_unexplained(latent_color: FloatArray) -> ScalarBool: | ||
"""A heuristic to check if a pixel does not have any latent point that hits | ||
it. | ||
Args: | ||
latent_color (FloatArray): The latent color of the pixel. | ||
Returns: | ||
bool: True is none of the latent point hits the pixel, False otherwise. | ||
""" | ||
return jnp.any(latent_color < 0.0) | ||
|
||
|
||
@Pytree.dataclass | ||
class PixelColorDistribution(genjax.ExactDensity): | ||
"""An abstract class that defines the common interface for pixel color kernels.""" | ||
|
||
@abstractmethod | ||
def sample( | ||
self, key: PRNGKey, latent_color: FloatArray, *args, **kwargs | ||
) -> FloatArray: | ||
raise NotImplementedError | ||
|
||
def logpdf( | ||
self, observed_color: FloatArray, latent_color: FloatArray, *args, **kwargs | ||
) -> ScalarFloat: | ||
return self.logpdf_per_channel( | ||
observed_color, latent_color, *args, **kwargs | ||
).sum() | ||
|
||
@abstractmethod | ||
def logpdf_per_channel( | ||
self, observed_color: FloatArray, latent_color: FloatArray, *args, **kwargs | ||
) -> FloatArray: | ||
"""Return an array of logpdf values, one for each channel. This is useful | ||
for testing purposes.""" | ||
raise NotImplementedError | ||
|
||
|
||
@Pytree.dataclass | ||
class TruncatedLaplacePixelColorDistribution(PixelColorDistribution): | ||
"""A distribution that generates the color of a pixel from a truncated | ||
Laplace distribution centered around the latent color, with the spread | ||
controlled by color_scale. The support of the distribution is ([0, 1]^3). | ||
""" | ||
|
||
color_scale: ScalarFloat | ||
# the uniform window is used to wrapped the truncated laplace distribution | ||
# to ensure that the color generated is within the range of [0, 1] | ||
uniform_window_size: ScalarFloat = Pytree.static( | ||
default=_FIXED_COLOR_UNIFORM_WINDOW | ||
) | ||
|
||
def sample( | ||
self, key: PRNGKey, latent_color: FloatArray, *args, **kwargs | ||
) -> FloatArray: | ||
return jax.vmap( | ||
lambda k, color: truncated_laplace.sample( | ||
k, | ||
color, | ||
self.color_scale, | ||
COLOR_MIN_VAL, | ||
COLOR_MAX_VAL, | ||
self.uniform_window_size, | ||
), | ||
in_axes=(0, 0), | ||
)(jax.random.split(key, latent_color.shape[0]), latent_color) | ||
|
||
def logpdf_per_channel( | ||
self, observed_color: FloatArray, latent_color: FloatArray, *args, **kwargs | ||
) -> FloatArray: | ||
return jax.vmap( | ||
lambda obs, latent: truncated_laplace.logpdf( | ||
obs, | ||
latent, | ||
self.color_scale, | ||
COLOR_MIN_VAL, | ||
COLOR_MAX_VAL, | ||
self.uniform_window_size, | ||
), | ||
in_axes=(0, 0), | ||
)(observed_color, latent_color) | ||
|
||
|
||
@Pytree.dataclass | ||
class UniformPixelColorDistribution(PixelColorDistribution): | ||
"""A distribution that generates the color of a pixel from a uniform on the | ||
RGB space ([0, 1]^3). | ||
""" | ||
|
||
@property | ||
def _base_dist(self) -> "dist.Distribution": | ||
return tfp.distributions.Uniform(COLOR_MIN_VAL, COLOR_MAX_VAL) | ||
|
||
def sample(self, key: PRNGKey, *args, **kwargs) -> FloatArray: | ||
return self._base_dist.sample(seed=key, sample_shape=(3,)) | ||
|
||
def logpdf_per_channel( | ||
self, observed_color: FloatArray, *args, **kwargs | ||
) -> FloatArray: | ||
return self._base_dist.log_prob(observed_color) | ||
|
||
|
||
@Pytree.dataclass | ||
class MixturePixelColorDistribution(PixelColorDistribution): | ||
"""A distribution that generates the color of a pixel from a mixture of a | ||
truncated Laplace distribution centered around the latent color (inlier | ||
branch) and a uniform distribution (outlier branch). The mixture is | ||
controlled by the color_outlier_prob parameter. The support of the | ||
distribution is ([0, 1]^3). | ||
""" | ||
|
||
color_scale: ScalarFloat | ||
|
||
@property | ||
def _inlier_dist(self) -> PixelColorDistribution: | ||
return TruncatedLaplacePixelColorDistribution(self.color_scale) | ||
|
||
@property | ||
def _outlier_dist(self) -> PixelColorDistribution: | ||
return UniformPixelColorDistribution() | ||
|
||
@property | ||
def _mixture_dists(self) -> tuple[PixelColorDistribution, PixelColorDistribution]: | ||
return (self._inlier_dist, self._outlier_dist) | ||
|
||
def get_mix_ratio(self, color_outlier_prob: ScalarFloat) -> FloatArray: | ||
return jnp.array((1 - color_outlier_prob, color_outlier_prob)) | ||
|
||
def sample( | ||
self, | ||
key: PRNGKey, | ||
latent_color: FloatArray, | ||
color_outlier_prob: ScalarFloat, | ||
*args, | ||
**kwargs, | ||
) -> FloatArray: | ||
return PythonMixturePixelModel(self._mixture_dists).sample( | ||
key, self.get_mix_ratio(color_outlier_prob), [(latent_color,), ()] | ||
) | ||
|
||
def logpdf_per_channel( | ||
self, | ||
observed_color: FloatArray, | ||
latent_color: FloatArray, | ||
color_outlier_prob: ScalarFloat, | ||
*args, | ||
**kwargs, | ||
) -> FloatArray: | ||
# Since the mixture model class does not keep the per-channel information, | ||
# we have to redefine this method to allow for testing | ||
logprobs = [] | ||
for dist, prob in zip( | ||
self._mixture_dists, self.get_mix_ratio(color_outlier_prob) | ||
): | ||
logprobs.append( | ||
dist.logpdf_per_channel(observed_color, latent_color) + jnp.log(prob) | ||
) | ||
|
||
return jnp.logaddexp(*logprobs) | ||
|
||
|
||
@Pytree.dataclass | ||
class FullPixelColorDistribution(PixelColorDistribution): | ||
"""A distribution that generates the color of the pixel according to the | ||
following rule: | ||
if no latent point hits the pixel: | ||
color ~ uniform(0, 1) | ||
else: | ||
color ~ mixture( | ||
[truncated_laplace(latent_color; color_scale), uniform(0, 1)], | ||
[1 - color_outlier_prob, color_outlier_prob] | ||
) | ||
""" | ||
|
||
color_scale: ScalarFloat | ||
|
||
@property | ||
def _color_from_latent(self) -> PixelColorDistribution: | ||
return MixturePixelColorDistribution(self.color_scale) | ||
|
||
@property | ||
def _unexplained_color(self) -> PixelColorDistribution: | ||
return UniformPixelColorDistribution() | ||
|
||
def sample( | ||
self, | ||
key: PRNGKey, | ||
latent_color: FloatArray, | ||
color_outlier_prob: FloatArray, | ||
*args, | ||
**kwargs, | ||
) -> FloatArray: | ||
# Check if any of the latent point hits the current pixel | ||
is_explained = ~is_unexplained(latent_color) | ||
|
||
return jax.lax.cond( | ||
is_explained, | ||
self._color_from_latent.sample, # if pixel is being hit by a latent point | ||
self._unexplained_color.sample, # if no point hits current pixel | ||
# sample args | ||
key, | ||
latent_color, | ||
color_outlier_prob, | ||
) | ||
|
||
def logpdf_per_channel( | ||
self, | ||
observed_color: FloatArray, | ||
latent_color: FloatArray, | ||
color_outlier_prob: ScalarFloat, | ||
*args, | ||
**kwargs, | ||
) -> FloatArray: | ||
# Check if any of the latent point hits the current pixel | ||
is_explained = ~is_unexplained(latent_color) | ||
|
||
return jax.lax.cond( | ||
is_explained, | ||
self._color_from_latent.logpdf_per_channel, # if pixel is being hit by a latent point | ||
self._unexplained_color.logpdf_per_channel, # if no point hits current pixel | ||
# logpdf args | ||
observed_color, | ||
latent_color, | ||
color_outlier_prob, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
from functools import partial | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import pytest | ||
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import ( | ||
COLOR_MAX_VAL, | ||
COLOR_MIN_VAL, | ||
FullPixelColorDistribution, | ||
MixturePixelColorDistribution, | ||
TruncatedLaplacePixelColorDistribution, | ||
UniformPixelColorDistribution, | ||
) | ||
from genjax.typing import FloatArray | ||
|
||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def generate_color_grid(n_grid_steps: int): | ||
"""Generate a grid of colors in very small interval to test that our logpdfs | ||
sum to 1. Since enumerating all color combination in 3 channels is infeasible, | ||
we take advantage of the fact that the channels are independent and only | ||
grid over the first channel. | ||
Args: | ||
n_grid_steps (int): The number of grid steps to generate | ||
Returns: | ||
FloatArray: A grid of colors with shape (n_grid_steps, 3), where the first | ||
channel is swept from 0 to 1 and the other two channels are fixed at 1. | ||
""" | ||
sweep_color_vals = jnp.linspace(COLOR_MIN_VAL, COLOR_MAX_VAL, n_grid_steps) | ||
fixed_color_vals = jnp.ones(n_grid_steps) | ||
return jnp.stack([sweep_color_vals, fixed_color_vals, fixed_color_vals], axis=-1) | ||
|
||
|
||
sample_kernels_to_test = [ | ||
(UniformPixelColorDistribution(), ()), | ||
(TruncatedLaplacePixelColorDistribution(0.1), ()), | ||
(MixturePixelColorDistribution(0.3), (0.5,)), # color_outlier_prob | ||
(FullPixelColorDistribution(0.5), (0.3,)), # color_outlier_prob | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("latent_color", [jnp.array([0.2, 0.5, 0.3]), jnp.zeros(3)]) | ||
@pytest.mark.parametrize("kernel_spec", sample_kernels_to_test) | ||
def test_logpdf_sum_to_1(kernel_spec, latent_color: FloatArray): | ||
kernel, additional_args = kernel_spec | ||
n_grid_steps = 10000000 | ||
color_grid = generate_color_grid(n_grid_steps) | ||
logpdf_per_channels = jax.vmap( | ||
lambda color: kernel.logpdf_per_channel(color, latent_color, *additional_args) | ||
)(color_grid) | ||
log_pmass = jax.scipy.special.logsumexp(logpdf_per_channels[..., 0]) - jnp.log( | ||
n_grid_steps | ||
) | ||
assert jnp.isclose(log_pmass, 0.0, atol=1e-3) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"latent_color", [jnp.array([0.25, 0.87, 0.31]), jnp.zeros(3), jnp.ones(3)] | ||
) | ||
@pytest.mark.parametrize("kernel_spec", sample_kernels_to_test) | ||
def test_sample_in_valid_color_range(kernel_spec, latent_color): | ||
kernel, additional_args = kernel_spec | ||
num_samples = 1000 | ||
keys = jax.random.split(jax.random.PRNGKey(0), num_samples) | ||
colors = jax.vmap(lambda key: kernel.sample(key, latent_color, *additional_args))( | ||
keys | ||
) | ||
assert colors.shape == (num_samples, 3) | ||
assert jnp.all(colors > 0) | ||
assert jnp.all(colors < 1) | ||
|
||
|
||
def test_relative_logpdf(): | ||
kernel = FullPixelColorDistribution(0.01) | ||
obs_color = jnp.array([0.0, 0.0, 1.0]) # a blue pixel | ||
|
||
# case 1: no color hit the pixel | ||
latent_color = -jnp.ones(3) # use -1 to denote invalid pixel | ||
logpdf_1 = kernel.logpdf(obs_color, latent_color, 0.2) | ||
logpdf_2 = kernel.logpdf(obs_color, latent_color, 0.8) | ||
# the logpdf should be the same because the outlier probability is not used | ||
# in the case when no color hit the pixel | ||
assert jnp.allclose(logpdf_1, logpdf_2) | ||
|
||
# case 2: a color hit the pixel, but the color is not close to the observed color | ||
latent_color = jnp.array([1.0, 0.5, 0.0]) | ||
logpdf_3 = kernel.logpdf(obs_color, latent_color, 0.2) | ||
logpdf_4 = kernel.logpdf(obs_color, latent_color, 0.8) | ||
# the pixel should be more likely to be an outlier | ||
assert logpdf_3 < logpdf_4 | ||
|
||
# case 3: a color hit the pixel, and the color is close to the observed color | ||
latent_color = jnp.array([0.0, 0.0, 0.9]) | ||
logpdf_5 = kernel.logpdf(obs_color, latent_color, 0.2) | ||
logpdf_6 = kernel.logpdf(obs_color, latent_color, 0.8) | ||
# the pixel should be more likely to be an inlier | ||
assert logpdf_5 > logpdf_6 | ||
# the score of the pixel should be higher when the color is closer | ||
assert logpdf_5 > logpdf_3 |