Skip to content

Commit

Permalink
fix: Ensure TensorFlow backend Poisson compatibility with other backe…
Browse files Browse the repository at this point in the history
…nds (#1001)

* Use differentiable TensorFlow operations to determine if Poisson(n=0 | lam=0) is being encountered and ensure a return value compatible with PyTorch and JAX
* Use TensorFlow releases compatible with v2.X with a lower bound of v2.2.1
   - Disallow TensorFlow v2.3.0 for accidental pinning of SciPy
* Use TensorFlow Probability releases compatible with v0.X with a lower bound of v0.10.1
* Use tf.errors.InvalidArgumentError for compatibility across TensorFlow releases
  • Loading branch information
matthewfeickert authored Sep 2, 2021
1 parent b7a2c65 commit 5247c16
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 6 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
extras_require = {
'shellcomplete': ['click_completion'],
'tensorflow': [
'tensorflow~=2.2.1', # TensorFlow minor releases are as volatile as major
'tensorflow-probability~=0.10.1',
'tensorflow~=2.2,>=2.2.1,!=2.3.0', # c.f. https://github.com/tensorflow/tensorflow/pull/40789
'tensorflow-probability~=0.10,>=0.10.1',
],
'torch': ['torch~=1.8'],
'jax': ['jax~=0.2.8', 'jaxlib~=0.1.58,!=0.1.68'], # c.f. Issue 1501
Expand Down
23 changes: 20 additions & 3 deletions src/pyhf/tensor/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import tensorflow as tf
import tensorflow_probability as tfp
from numpy import nan

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -119,7 +120,7 @@ def tile(self, tensor_in, repeats):
"""
try:
return tf.tile(tensor_in, repeats)
except tf.python.framework.errors_impl.InvalidArgumentError:
except tf.errors.InvalidArgumentError:
shape = tf.shape(tensor_in).numpy().tolist()
diff = len(repeats) - len(shape)
if diff < 0:
Expand Down Expand Up @@ -426,8 +427,15 @@ def poisson_logpdf(self, n, lam):
TensorFlow Tensor: Value of the continuous approximation to log(Poisson(n|lam))
"""
lam = self.astensor(lam)
# Guard against Poisson(n=0 | lam=0)
# c.f. https://github.com/scikit-hep/pyhf/issues/293
valid_obs_given_rate = tf.logical_or(
tf.math.not_equal(lam, n), tf.math.not_equal(n, 0)
)

return tfp.distributions.Poisson(lam).log_prob(n)
return tf.where(
valid_obs_given_rate, tfp.distributions.Poisson(lam).log_prob(n), nan
)

def poisson(self, n, lam):
r"""
Expand Down Expand Up @@ -457,8 +465,17 @@ def poisson(self, n, lam):
TensorFlow Tensor: Value of the continuous approximation to Poisson(n|lam)
"""
lam = self.astensor(lam)
# Guard against Poisson(n=0 | lam=0)
# c.f. https://github.com/scikit-hep/pyhf/issues/293
valid_obs_given_rate = tf.logical_or(
tf.math.not_equal(lam, n), tf.math.not_equal(n, 0)
)

return tf.exp(tfp.distributions.Poisson(lam).log_prob(n))
return tf.where(
valid_obs_given_rate,
tf.exp(tfp.distributions.Poisson(lam).log_prob(n)),
nan,
)

def normal_logpdf(self, x, mu, sigma):
r"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def test_tensor_tile(backend):
]

if tb.name == 'tensorflow':
with pytest.raises(tf.python.framework.errors_impl.InvalidArgumentError):
with pytest.raises(tf.errors.InvalidArgumentError):
tb.tile(tb.astensor([[[10, 20, 30]]]), (2, 1))


Expand Down

0 comments on commit 5247c16

Please sign in to comment.