Skip to content

Commit

Permalink
Handle unconstrained priors in randomize_hyperparameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Nov 24, 2023
1 parent aa94c86 commit cf74cb2
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 14 deletions.
27 changes: 25 additions & 2 deletions tests/unit/models/gpflow/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,29 @@ def test_randomize_hyperparameters_randomizes_kernel_parameters_with_priors(
assert len(np.unique(kernel.lengthscales)) == dim


@random_seed
@pytest.mark.parametrize("compile", [False, True])
def test_randomize_hyperparameters_randomizes_kernel_parameters_with_unconstrained_priors(
dim: int, compile: bool
) -> None:
kernel = gpflow.kernels.RBF(variance=1.0, lengthscales=[0.2] * dim)
kernel.lengthscales = gpflow.Parameter(kernel.lengthscales, transform=tfp.bijectors.Exp())
kernel.lengthscales.prior = tfp.distributions.Uniform(
tf.math.log(tf.constant(0.01, dtype=tf.float64)),
tf.math.log(tf.constant(10.0, dtype=tf.float64)),
)
kernel.lengthscales.prior_on = gpflow.base.PriorOn.UNCONSTRAINED

compiler = tf.function if compile else lambda x: x
compiler(randomize_hyperparameters)(kernel)

npt.assert_allclose(1.0, kernel.variance)
npt.assert_array_equal(dim, kernel.lengthscales.shape)
npt.assert_array_less(kernel.lengthscales, [10.0] * dim)
npt.assert_raises(AssertionError, npt.assert_allclose, [0.2] * dim, kernel.lengthscales)
assert len(np.unique(kernel.lengthscales)) == dim


@random_seed
@pytest.mark.parametrize("compile", [False, True])
def test_randomize_hyperparameters_randomizes_kernel_parameters_with_const_priors(
Expand Down Expand Up @@ -207,13 +230,13 @@ def test_randomize_hyperparameters_samples_from_constraints_when_given_prior_and
kernel.lengthscales = gpflow.Parameter(
kernel.lengthscales, transform=tfp.bijectors.Sigmoid(low=lower, high=upper)
)
kernel.lengthscales.prior = tfp.distributions.Uniform(low=10.0, high=100.0)
kernel.lengthscales.prior = tfp.distributions.Uniform(low=lower, high=upper / 2)

kernel.variance.prior = tfp.distributions.LogNormal(loc=np.float64(-2.0), scale=np.float64(1.0))

randomize_hyperparameters(kernel)

npt.assert_array_less(kernel.lengthscales, [0.5] * dim)
npt.assert_array_less(kernel.lengthscales, [0.25] * dim)
npt.assert_raises(AssertionError, npt.assert_allclose, [0.2] * dim, kernel.lengthscales)


Expand Down
29 changes: 17 additions & 12 deletions trieste/models/gpflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@

from __future__ import annotations

from typing import Tuple, Union
from typing import Tuple, Union, Optional

import gpflow
import tensorflow as tf
import tensorflow_probability as tfp
from gpflow.base import TensorData

from ...data import Dataset
from ...types import TensorType
Expand Down Expand Up @@ -52,21 +53,13 @@ def assert_data_is_compatible(new_data: Dataset, existing_data: Dataset) -> None

def randomize_hyperparameters(object: gpflow.Module) -> None:
"""
Sets hyperparameters to random samples from their constrained domains or (if not constraints
are available) their prior distributions.
Sets hyperparameters to random samples from their prior distributions or (for Sigmoid
constraints with no priors) their constrained domains.
:param object: Any gpflow Module.
"""
for param in object.trainable_parameters:
if isinstance(param.bijector, tfp.bijectors.Sigmoid):
sample = tf.random.uniform(
param.bijector.low.shape,
minval=param.bijector.low,
maxval=param.bijector.high,
dtype=param.bijector.low.dtype,
)
param.assign(sample)
elif param.prior is not None:
if param.prior is not None:
# handle constant priors for multi-dimensional parameters
# Use python conditionals here to avoid creating tensorflow `tf.cond` ops,
# i.e. using `len(param.shape)` instead of `tf.rank(param)`.
Expand All @@ -76,6 +69,18 @@ def randomize_hyperparameters(object: gpflow.Module) -> None:
sample = param.prior.sample(tf.shape(param))
else:
sample = param.prior.sample()
if param.prior_on is gpflow.base.PriorOn.UNCONSTRAINED:
param.unconstrained_variable.assign(sample)
else:
param.assign(sample)

elif isinstance(param.bijector, tfp.bijectors.Sigmoid):
sample = tf.random.uniform(
param.bijector.low.shape,
minval=param.bijector.low,
maxval=param.bijector.high,
dtype=param.bijector.low.dtype,
)
param.assign(sample)


Expand Down

0 comments on commit cf74cb2

Please sign in to comment.