From 5247c1600dde68294014bcb2359b33e462c7d88e Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Thu, 2 Sep 2021 09:21:27 -0500 Subject: [PATCH] fix: Ensure TensorFlow backend Poisson compatibility with other backends (#1001) * Use differentiable TensorFlow operations to determine if Poisson(n=0 | lam=0) is being encountered and ensure a return value compatible with PyTorch and JAX * Use TensorFlow releases compatible with v2.X with a lower bound of v2.2.1 - Disallow TensorFlow v2.3.0 for accidental pinning of SciPy * Use TensorFlow Probability releases compatible with v0.X with a lower bound of v0.10.1 * Use tf.errors.InvalidArgumentError for compatibility across TensorFlow releases --- setup.py | 4 ++-- src/pyhf/tensor/tensorflow_backend.py | 23 ++++++++++++++++++++--- tests/test_tensor.py | 2 +- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 1bd14f53f5..2b56fc882f 100644 --- a/setup.py +++ b/setup.py @@ -3,8 +3,8 @@ extras_require = { 'shellcomplete': ['click_completion'], 'tensorflow': [ - 'tensorflow~=2.2.1', # TensorFlow minor releases are as volatile as major - 'tensorflow-probability~=0.10.1', + 'tensorflow~=2.2,>=2.2.1,!=2.3.0', # c.f. https://github.com/tensorflow/tensorflow/pull/40789 + 'tensorflow-probability~=0.10,>=0.10.1', ], 'torch': ['torch~=1.8'], 'jax': ['jax~=0.2.8', 'jaxlib~=0.1.58,!=0.1.68'], # c.f. Issue 1501 diff --git a/src/pyhf/tensor/tensorflow_backend.py b/src/pyhf/tensor/tensorflow_backend.py index d5c566475c..be41e8f488 100644 --- a/src/pyhf/tensor/tensorflow_backend.py +++ b/src/pyhf/tensor/tensorflow_backend.py @@ -2,6 +2,7 @@ import logging import tensorflow as tf import tensorflow_probability as tfp +from numpy import nan log = logging.getLogger(__name__) @@ -119,7 +120,7 @@ def tile(self, tensor_in, repeats): """ try: return tf.tile(tensor_in, repeats) - except tf.python.framework.errors_impl.InvalidArgumentError: + except tf.errors.InvalidArgumentError: shape = tf.shape(tensor_in).numpy().tolist() diff = len(repeats) - len(shape) if diff < 0: @@ -426,8 +427,15 @@ def poisson_logpdf(self, n, lam): TensorFlow Tensor: Value of the continuous approximation to log(Poisson(n|lam)) """ lam = self.astensor(lam) + # Guard against Poisson(n=0 | lam=0) + # c.f. https://github.com/scikit-hep/pyhf/issues/293 + valid_obs_given_rate = tf.logical_or( + tf.math.not_equal(lam, n), tf.math.not_equal(n, 0) + ) - return tfp.distributions.Poisson(lam).log_prob(n) + return tf.where( + valid_obs_given_rate, tfp.distributions.Poisson(lam).log_prob(n), nan + ) def poisson(self, n, lam): r""" @@ -457,8 +465,17 @@ def poisson(self, n, lam): TensorFlow Tensor: Value of the continuous approximation to Poisson(n|lam) """ lam = self.astensor(lam) + # Guard against Poisson(n=0 | lam=0) + # c.f. https://github.com/scikit-hep/pyhf/issues/293 + valid_obs_given_rate = tf.logical_or( + tf.math.not_equal(lam, n), tf.math.not_equal(n, 0) + ) - return tf.exp(tfp.distributions.Poisson(lam).log_prob(n)) + return tf.where( + valid_obs_given_rate, + tf.exp(tfp.distributions.Poisson(lam).log_prob(n)), + nan, + ) def normal_logpdf(self, x, mu, sigma): r""" diff --git a/tests/test_tensor.py b/tests/test_tensor.py index 612714b477..81a01f14c0 100644 --- a/tests/test_tensor.py +++ b/tests/test_tensor.py @@ -392,7 +392,7 @@ def test_tensor_tile(backend): ] if tb.name == 'tensorflow': - with pytest.raises(tf.python.framework.errors_impl.InvalidArgumentError): + with pytest.raises(tf.errors.InvalidArgumentError): tb.tile(tb.astensor([[[10, 20, 30]]]), (2, 1))