diff --git a/tests/unit/models/gpflow/test_utils.py b/tests/unit/models/gpflow/test_utils.py index dcefe1fdb2..770f21f7fd 100644 --- a/tests/unit/models/gpflow/test_utils.py +++ b/tests/unit/models/gpflow/test_utils.py @@ -143,6 +143,29 @@ def test_randomize_hyperparameters_randomizes_kernel_parameters_with_priors( assert len(np.unique(kernel.lengthscales)) == dim +@random_seed +@pytest.mark.parametrize("compile", [False, True]) +def test_randomize_hyperparameters_randomizes_kernel_parameters_with_unconstrained_priors( + dim: int, compile: bool +) -> None: + kernel = gpflow.kernels.RBF(variance=1.0, lengthscales=[0.2] * dim) + kernel.lengthscales = gpflow.Parameter(kernel.lengthscales, transform=tfp.bijectors.Exp()) + kernel.lengthscales.prior = tfp.distributions.Uniform( + tf.math.log(tf.constant(0.01, dtype=tf.float64)), + tf.math.log(tf.constant(10.0, dtype=tf.float64)), + ) + kernel.lengthscales.prior_on = gpflow.base.PriorOn.UNCONSTRAINED + + compiler = tf.function if compile else lambda x: x + compiler(randomize_hyperparameters)(kernel) + + npt.assert_allclose(1.0, kernel.variance) + npt.assert_array_equal(dim, kernel.lengthscales.shape) + npt.assert_array_less(kernel.lengthscales, [10.0] * dim) + npt.assert_raises(AssertionError, npt.assert_allclose, [0.2] * dim, kernel.lengthscales) + assert len(np.unique(kernel.lengthscales)) == dim + + @random_seed @pytest.mark.parametrize("compile", [False, True]) def test_randomize_hyperparameters_randomizes_kernel_parameters_with_const_priors( @@ -207,13 +230,13 @@ def test_randomize_hyperparameters_samples_from_constraints_when_given_prior_and kernel.lengthscales = gpflow.Parameter( kernel.lengthscales, transform=tfp.bijectors.Sigmoid(low=lower, high=upper) ) - kernel.lengthscales.prior = tfp.distributions.Uniform(low=10.0, high=100.0) + kernel.lengthscales.prior = tfp.distributions.Uniform(low=lower, high=upper / 2) kernel.variance.prior = tfp.distributions.LogNormal(loc=np.float64(-2.0), scale=np.float64(1.0)) randomize_hyperparameters(kernel) - npt.assert_array_less(kernel.lengthscales, [0.5] * dim) + npt.assert_array_less(kernel.lengthscales, [0.25] * dim) npt.assert_raises(AssertionError, npt.assert_allclose, [0.2] * dim, kernel.lengthscales) diff --git a/trieste/models/gpflow/utils.py b/trieste/models/gpflow/utils.py index 42598b325b..ec6f9aa779 100644 --- a/trieste/models/gpflow/utils.py +++ b/trieste/models/gpflow/utils.py @@ -14,11 +14,12 @@ from __future__ import annotations -from typing import Tuple, Union +from typing import Tuple, Union, Optional import gpflow import tensorflow as tf import tensorflow_probability as tfp +from gpflow.base import TensorData from ...data import Dataset from ...types import TensorType @@ -52,21 +53,13 @@ def assert_data_is_compatible(new_data: Dataset, existing_data: Dataset) -> None def randomize_hyperparameters(object: gpflow.Module) -> None: """ - Sets hyperparameters to random samples from their constrained domains or (if not constraints - are available) their prior distributions. + Sets hyperparameters to random samples from their prior distributions or (for Sigmoid + constraints with no priors) their constrained domains. :param object: Any gpflow Module. """ for param in object.trainable_parameters: - if isinstance(param.bijector, tfp.bijectors.Sigmoid): - sample = tf.random.uniform( - param.bijector.low.shape, - minval=param.bijector.low, - maxval=param.bijector.high, - dtype=param.bijector.low.dtype, - ) - param.assign(sample) - elif param.prior is not None: + if param.prior is not None: # handle constant priors for multi-dimensional parameters # Use python conditionals here to avoid creating tensorflow `tf.cond` ops, # i.e. using `len(param.shape)` instead of `tf.rank(param)`. @@ -76,6 +69,18 @@ def randomize_hyperparameters(object: gpflow.Module) -> None: sample = param.prior.sample(tf.shape(param)) else: sample = param.prior.sample() + if param.prior_on is gpflow.base.PriorOn.UNCONSTRAINED: + param.unconstrained_variable.assign(sample) + else: + param.assign(sample) + + elif isinstance(param.bijector, tfp.bijectors.Sigmoid): + sample = tf.random.uniform( + param.bijector.low.shape, + minval=param.bijector.low, + maxval=param.bijector.high, + dtype=param.bijector.low.dtype, + ) param.assign(sample)