From 96348fff73fcdda8e15086dedb9f3dd75f6cb8fa Mon Sep 17 00:00:00 2001 From: Samuel Willis Date: Wed, 11 Oct 2023 13:09:08 +0100 Subject: [PATCH 1/2] change sampler --- trieste/models/gpflow/sampler.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/trieste/models/gpflow/sampler.py b/trieste/models/gpflow/sampler.py index a8316255d2..a88f8a9e1a 100644 --- a/trieste/models/gpflow/sampler.py +++ b/trieste/models/gpflow/sampler.py @@ -100,7 +100,9 @@ def __init__( :raise ValueError (or InvalidArgumentError): If ``sample_size`` is not positive. """ super().__init__(sample_size, model) - self._eps: Optional[tf.Variable] = None + self._eps = tf.Variable( + tf.zeros(shape=(sample_size, 0), dtype=tf.float64), shape=(sample_size, None), dtype=tf.float64 + ) self._qmc = qmc self._qmc_skip = qmc_skip @@ -141,9 +143,6 @@ def sample_eps() -> tf.Tensor: ) return normal_samples # [S, L] - if self._eps is None: - self._eps = tf.Variable(sample_eps()) - tf.cond( self._initialized, lambda: self._eps, From 54b5c58e748d5cd700b129af6f12dcfbc8a63d80 Mon Sep 17 00:00:00 2001 From: Samuel Willis Date: Fri, 13 Oct 2023 10:43:05 +0100 Subject: [PATCH 2/2] formatting --- trieste/models/gpflow/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trieste/models/gpflow/sampler.py b/trieste/models/gpflow/sampler.py index a88f8a9e1a..a83b2f178a 100644 --- a/trieste/models/gpflow/sampler.py +++ b/trieste/models/gpflow/sampler.py @@ -101,7 +101,9 @@ def __init__( """ super().__init__(sample_size, model) self._eps = tf.Variable( - tf.zeros(shape=(sample_size, 0), dtype=tf.float64), shape=(sample_size, None), dtype=tf.float64 + tf.zeros(shape=(sample_size, 0), dtype=tf.float64), + shape=(sample_size, None), + dtype=tf.float64, ) self._qmc = qmc self._qmc_skip = qmc_skip