diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py
index 1ffb378c14..a6c872dcfb 100644
--- a/src/pyhf/tensor/jax_backend.py
+++ b/src/pyhf/tensor/jax_backend.py
@@ -304,6 +304,26 @@ def shape(self, tensor):
     def reshape(self, tensor, newshape):
         return jnp.reshape(tensor, newshape)
 
+    def ravel(self, tensor):
+        """
+        Return a flattened view of the tensor, not a copy.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend("jax")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+            >>> pyhf.tensorlib.ravel(tensor)
+            DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)
+
+        Args:
+            tensor (Tensor): Tensor object
+
+        Returns:
+            `jax.interpreters.xla.DeviceArray`: A flattened array.
+        """
+        return jnp.ravel(tensor)
+
     def einsum(self, subscripts, *operands):
         """
         Evaluates the Einstein summation convention on the operands.
diff --git a/src/pyhf/tensor/numpy_backend.py b/src/pyhf/tensor/numpy_backend.py
index ee0d1460ad..8b22460021 100644
--- a/src/pyhf/tensor/numpy_backend.py
+++ b/src/pyhf/tensor/numpy_backend.py
@@ -289,6 +289,26 @@ def shape(self, tensor):
     def reshape(self, tensor, newshape):
         return np.reshape(tensor, newshape)
 
+    def ravel(self, tensor):
+        """
+        Return a flattened view of the tensor, not a copy.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend("numpy")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+            >>> pyhf.tensorlib.ravel(tensor)
+            array([1., 2., 3., 4., 5., 6.])
+
+        Args:
+            tensor (Tensor): Tensor object
+
+        Returns:
+            `numpy.ndarray`: A flattened array.
+        """
+        return np.ravel(tensor)
+
     def einsum(self, subscripts, *operands):
         """
         Evaluates the Einstein summation convention on the operands.
diff --git a/src/pyhf/tensor/pytorch_backend.py b/src/pyhf/tensor/pytorch_backend.py
index 935200221a..2a13ff6516 100644
--- a/src/pyhf/tensor/pytorch_backend.py
+++ b/src/pyhf/tensor/pytorch_backend.py
@@ -189,6 +189,26 @@ def reshape(self, tensor, newshape):
     def shape(self, tensor):
         return tuple(map(int, tensor.shape))
 
+    def ravel(self, tensor):
+        """
+        Return a flattened view of the tensor, not a copy.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend("pytorch")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+            >>> pyhf.tensorlib.ravel(tensor)
+            tensor([1., 2., 3., 4., 5., 6.])
+
+        Args:
+            tensor (Tensor): Tensor object
+
+        Returns:
+            `torch.Tensor`: A flattened array.
+        """
+        return tensor.view(-1)
+
     def sum(self, tensor_in, axis=None):
         return (
             torch.sum(tensor_in)
diff --git a/src/pyhf/tensor/tensorflow_backend.py b/src/pyhf/tensor/tensorflow_backend.py
index d8f3778fc1..46abcf1ab3 100644
--- a/src/pyhf/tensor/tensorflow_backend.py
+++ b/src/pyhf/tensor/tensorflow_backend.py
@@ -256,6 +256,27 @@ def shape(self, tensor):
     def reshape(self, tensor, newshape):
         return tf.reshape(tensor, newshape)
 
+    def ravel(self, tensor):
+        """
+        Return a flattened view of the tensor, not a copy.
+
+        Example:
+
+            >>> import pyhf
+            >>> pyhf.set_backend("tensorflow")
+            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+            >>> t_ravel = pyhf.tensorlib.ravel(tensor)
+            >>> print(t_ravel)
+            tf.Tensor([1. 2. 3. 4. 5. 6.], shape=(6,), dtype=float32)
+
+        Args:
+            tensor (Tensor): Tensor object
+
+        Returns:
+            `tf.Tensor`: A flattened array.
+        """
+        return self.reshape(tensor, -1)
+
     def divide(self, tensor_in_1, tensor_in_2):
         return tf.divide(tensor_in_1, tensor_in_2)
 
diff --git a/tests/test_tensor.py b/tests/test_tensor.py
index a986d042cc..60de75af00 100644
--- a/tests/test_tensor.py
+++ b/tests/test_tensor.py
@@ -66,6 +66,41 @@ def test_simple_tensor_ops(backend):
     assert tb.tolist(tb.conditional((a > b), lambda: a + b, lambda: a - b)) == -1.0
 
 
+def test_tensor_where_scalar(backend):
+    tb = pyhf.tensorlib
+    assert tb.tolist(tb.where(tb.astensor([1, 0, 1], dtype="bool"), 1, 2)) == [1, 2, 1]
+
+
+def test_tensor_where_tensor(backend):
+    tb = pyhf.tensorlib
+    assert (
+        tb.tolist(
+            tb.where(
+                tb.astensor([1, 0, 1], dtype="bool"),
+                tb.astensor([1, 1, 1]),
+                tb.astensor([2, 2, 2]),
+            )
+        )
+        == [1, 2, 1]
+    )
+
+
+def test_tensor_ravel(backend):
+    tb = pyhf.tensorlib
+    assert (
+        tb.tolist(
+            tb.ravel(
+                tb.astensor(
+                    [
+                        [1, 2, 3],
+                        [4, 5, 6],
+                    ]
+                )
+            )
+        )
+    ) == [1, 2, 3, 4, 5, 6]
+
+
 def test_complex_tensor_ops(backend):
     tb = pyhf.tensorlib
     assert tb.tolist(tb.outer(tb.astensor([1, 2, 3]), tb.astensor([4, 5, 6]))) == [
@@ -90,16 +125,6 @@ def test_complex_tensor_ops(backend):
         1,
         1,
     ]
-    assert (
-        tb.tolist(
-            tb.where(
-                tb.astensor([1, 0, 1], dtype="bool"),
-                tb.astensor([1, 1, 1]),
-                tb.astensor([2, 2, 2]),
-            )
-        )
-        == [1, 2, 1]
-    )
 
 
 def test_ones(backend):