Skip to content

Commit

Permalink
Pixel Kernels (1/2): Color Kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
horizon-blue committed Sep 9, 2024
1 parent d6bd23c commit 66a3c56
Show file tree
Hide file tree
Showing 4 changed files with 351 additions and 0 deletions.
File renamed without changes.
Empty file.
250 changes: 250 additions & 0 deletions src/b3d/chisight/gen3d/pixel_kernels/pixel_color_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
from abc import abstractmethod
from typing import TYPE_CHECKING

import genjax
import jax
import jax.numpy as jnp
from genjax import Pytree
from genjax.typing import FloatArray, PRNGKey, ScalarBool, ScalarFloat
from tensorflow_probability.substrates import jax as tfp

from b3d.chisight.dense.likelihoods.other_likelihoods import PythonMixturePixelModel
from b3d.chisight.dynamic_object_model.likelihoods.kfold_image_kernel import (
_FIXED_COLOR_UNIFORM_WINDOW,
truncated_laplace,
)

if TYPE_CHECKING:
import tensorflow_probability.python.distributions.distribution as dist

COLOR_MIN_VAL: float = 0.0
COLOR_MAX_VAL: float = 1.0


def is_unexplained(latent_color: FloatArray) -> ScalarBool:
"""A heuristic to check if a pixel does not have any latent point that hits
it.
Args:
latent_color (FloatArray): The latent color of the pixel.
Returns:
bool: True is none of the latent point hits the pixel, False otherwise.
"""
return jnp.any(latent_color < 0.0)


@Pytree.dataclass
class PixelColorDistribution(genjax.ExactDensity):
"""An abstract class that defines the common interface for pixel color kernels."""

@abstractmethod
def sample(
self, key: PRNGKey, latent_color: FloatArray, *args, **kwargs
) -> FloatArray:
raise NotImplementedError

def logpdf(
self, observed_color: FloatArray, latent_color: FloatArray, *args, **kwargs
) -> ScalarFloat:
return self.logpdf_per_channel(
observed_color, latent_color, *args, **kwargs
).sum()

@abstractmethod
def logpdf_per_channel(
self, observed_color: FloatArray, latent_color: FloatArray, *args, **kwargs
) -> FloatArray:
"""Return an array of logpdf values, one for each channel. This is useful
for testing purposes."""
raise NotImplementedError


@Pytree.dataclass
class TruncatedLaplacePixelColorDistribution(PixelColorDistribution):
"""A distribution that generates the color of a pixel from a truncated
Laplace distribution centered around the latent color, with the spread
controlled by color_scale. The support of the distribution is ([0, 1]^3).
"""

color_scale: ScalarFloat
# the uniform window is used to wrapped the truncated laplace distribution
# to ensure that the color generated is within the range of [0, 1]
uniform_window_size: ScalarFloat = Pytree.static(
default=_FIXED_COLOR_UNIFORM_WINDOW
)

def sample(
self, key: PRNGKey, latent_color: FloatArray, *args, **kwargs
) -> FloatArray:
return jax.vmap(
lambda k, color: truncated_laplace.sample(
k,
color,
self.color_scale,
COLOR_MIN_VAL,
COLOR_MAX_VAL,
self.uniform_window_size,
),
in_axes=(0, 0),
)(jax.random.split(key, latent_color.shape[0]), latent_color)

def logpdf_per_channel(
self, observed_color: FloatArray, latent_color: FloatArray, *args, **kwargs
) -> FloatArray:
return jax.vmap(
lambda obs, latent: truncated_laplace.logpdf(
obs,
latent,
self.color_scale,
COLOR_MIN_VAL,
COLOR_MAX_VAL,
self.uniform_window_size,
),
in_axes=(0, 0),
)(observed_color, latent_color)


@Pytree.dataclass
class UniformPixelColorDistribution(PixelColorDistribution):
"""A distribution that generates the color of a pixel from a uniform on the
RGB space ([0, 1]^3).
"""

@property
def _base_dist(self) -> "dist.Distribution":
return tfp.distributions.Uniform(COLOR_MIN_VAL, COLOR_MAX_VAL)

def sample(self, key: PRNGKey, *args, **kwargs) -> FloatArray:
return self._base_dist.sample(seed=key, sample_shape=(3,))

def logpdf_per_channel(
self, observed_color: FloatArray, *args, **kwargs
) -> FloatArray:
return self._base_dist.log_prob(observed_color)


@Pytree.dataclass
class MixturePixelColorDistribution(PixelColorDistribution):
"""A distribution that generates the color of a pixel from a mixture of a
truncated Laplace distribution centered around the latent color (inlier
branch) and a uniform distribution (outlier branch). The mixture is
controlled by the color_outlier_prob parameter. The support of the
distribution is ([0, 1]^3).
"""

color_scale: ScalarFloat

@property
def _inlier_dist(self) -> PixelColorDistribution:
return TruncatedLaplacePixelColorDistribution(self.color_scale)

@property
def _outlier_dist(self) -> PixelColorDistribution:
return UniformPixelColorDistribution()

@property
def _mixture_dists(self) -> tuple[PixelColorDistribution, PixelColorDistribution]:
return (self._inlier_dist, self._outlier_dist)

def get_mix_ratio(self, color_outlier_prob: ScalarFloat) -> FloatArray:
return jnp.array((1 - color_outlier_prob, color_outlier_prob))

def sample(
self,
key: PRNGKey,
latent_color: FloatArray,
color_outlier_prob: ScalarFloat,
*args,
**kwargs,
) -> FloatArray:
return PythonMixturePixelModel(self._mixture_dists).sample(
key, self.get_mix_ratio(color_outlier_prob), [(latent_color,), ()]
)

def logpdf_per_channel(
self,
observed_color: FloatArray,
latent_color: FloatArray,
color_outlier_prob: ScalarFloat,
*args,
**kwargs,
) -> FloatArray:
# Since the mixture model class does not keep the per-channel information,
# we have to redefine this method to allow for testing
logprobs = []
for dist, prob in zip(
self._mixture_dists, self.get_mix_ratio(color_outlier_prob)
):
logprobs.append(
dist.logpdf_per_channel(observed_color, latent_color) + jnp.log(prob)
)

return jnp.logaddexp(*logprobs)


@Pytree.dataclass
class FullPixelColorDistribution(PixelColorDistribution):
"""A distribution that generates the color of the pixel according to the
following rule:
if no latent point hits the pixel:
color ~ uniform(0, 1)
else:
color ~ mixture(
[truncated_laplace(latent_color; color_scale), uniform(0, 1)],
[1 - color_outlier_prob, color_outlier_prob]
)
"""

color_scale: ScalarFloat

@property
def _color_from_latent(self) -> PixelColorDistribution:
return MixturePixelColorDistribution(self.color_scale)

@property
def _unexplained_color(self) -> PixelColorDistribution:
return UniformPixelColorDistribution()

def sample(
self,
key: PRNGKey,
latent_color: FloatArray,
color_outlier_prob: FloatArray,
*args,
**kwargs,
) -> FloatArray:
# Check if any of the latent point hits the current pixel
is_explained = ~is_unexplained(latent_color)

return jax.lax.cond(
is_explained,
self._color_from_latent.sample, # if pixel is being hit by a latent point
self._unexplained_color.sample, # if no point hits current pixel
# sample args
key,
latent_color,
color_outlier_prob,
)

def logpdf_per_channel(
self,
observed_color: FloatArray,
latent_color: FloatArray,
color_outlier_prob: ScalarFloat,
*args,
**kwargs,
) -> FloatArray:
# Check if any of the latent point hits the current pixel
is_explained = ~is_unexplained(latent_color)

return jax.lax.cond(
is_explained,
self._color_from_latent.logpdf_per_channel, # if pixel is being hit by a latent point
self._unexplained_color.logpdf_per_channel, # if no point hits current pixel
# logpdf args
observed_color,
latent_color,
color_outlier_prob,
)
101 changes: 101 additions & 0 deletions tests/gen3d/test_pixel_color_kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from functools import partial

import jax
import jax.numpy as jnp
import pytest
from b3d.chisight.gen3d.pixel_kernels.pixel_color_kernels import (
COLOR_MAX_VAL,
COLOR_MIN_VAL,
FullPixelColorDistribution,
MixturePixelColorDistribution,
TruncatedLaplacePixelColorDistribution,
UniformPixelColorDistribution,
)
from genjax.typing import FloatArray


@partial(jax.jit, static_argnums=(0,))
def generate_color_grid(n_grid_steps: int):
"""Generate a grid of colors in very small interval to test that our logpdfs
sum to 1. Since enumerating all color combination in 3 channels is infeasible,
we take advantage of the fact that the channels are independent and only
grid over the first channel.
Args:
n_grid_steps (int): The number of grid steps to generate
Returns:
FloatArray: A grid of colors with shape (n_grid_steps, 3), where the first
channel is swept from 0 to 1 and the other two channels are fixed at 1.
"""
sweep_color_vals = jnp.linspace(COLOR_MIN_VAL, COLOR_MAX_VAL, n_grid_steps)
fixed_color_vals = jnp.ones(n_grid_steps)
return jnp.stack([sweep_color_vals, fixed_color_vals, fixed_color_vals], axis=-1)


sample_kernels_to_test = [
(UniformPixelColorDistribution(), ()),
(TruncatedLaplacePixelColorDistribution(0.1), ()),
(MixturePixelColorDistribution(0.3), (0.5,)), # color_outlier_prob
(FullPixelColorDistribution(0.5), (0.3,)), # color_outlier_prob
]


@pytest.mark.parametrize("latent_color", [jnp.array([0.2, 0.5, 0.3]), jnp.zeros(3)])
@pytest.mark.parametrize("kernel_spec", sample_kernels_to_test)
def test_logpdf_sum_to_1(kernel_spec, latent_color: FloatArray):
kernel, additional_args = kernel_spec
n_grid_steps = 10000000
color_grid = generate_color_grid(n_grid_steps)
logpdf_per_channels = jax.vmap(
lambda color: kernel.logpdf_per_channel(color, latent_color, *additional_args)
)(color_grid)
log_pmass = jax.scipy.special.logsumexp(logpdf_per_channels[..., 0]) - jnp.log(
n_grid_steps
)
assert jnp.isclose(log_pmass, 0.0, atol=1e-3)


@pytest.mark.parametrize(
"latent_color", [jnp.array([0.25, 0.87, 0.31]), jnp.zeros(3), jnp.ones(3)]
)
@pytest.mark.parametrize("kernel_spec", sample_kernels_to_test)
def test_sample_in_valid_color_range(kernel_spec, latent_color):
kernel, additional_args = kernel_spec
num_samples = 1000
keys = jax.random.split(jax.random.PRNGKey(0), num_samples)
colors = jax.vmap(lambda key: kernel.sample(key, latent_color, *additional_args))(
keys
)
assert colors.shape == (num_samples, 3)
assert jnp.all(colors > 0)
assert jnp.all(colors < 1)


def test_relative_logpdf():
kernel = FullPixelColorDistribution(0.01)
obs_color = jnp.array([0.0, 0.0, 1.0]) # a blue pixel

# case 1: no color hit the pixel
latent_color = -jnp.ones(3) # use -1 to denote invalid pixel
logpdf_1 = kernel.logpdf(obs_color, latent_color, 0.2)
logpdf_2 = kernel.logpdf(obs_color, latent_color, 0.8)
# the logpdf should be the same because the outlier probability is not used
# in the case when no color hit the pixel
assert jnp.allclose(logpdf_1, logpdf_2)

# case 2: a color hit the pixel, but the color is not close to the observed color
latent_color = jnp.array([1.0, 0.5, 0.0])
logpdf_3 = kernel.logpdf(obs_color, latent_color, 0.2)
logpdf_4 = kernel.logpdf(obs_color, latent_color, 0.8)
# the pixel should be more likely to be an outlier
assert logpdf_3 < logpdf_4

# case 3: a color hit the pixel, and the color is close to the observed color
latent_color = jnp.array([0.0, 0.0, 0.9])
logpdf_5 = kernel.logpdf(obs_color, latent_color, 0.2)
logpdf_6 = kernel.logpdf(obs_color, latent_color, 0.8)
# the pixel should be more likely to be an inlier
assert logpdf_5 > logpdf_6
# the score of the pixel should be higher when the color is closer
assert logpdf_5 > logpdf_3

0 comments on commit 66a3c56

Please sign in to comment.