feat: Tensorlib ravel functionality (#1147)
* Add ravel method to tensorlib for flattened view of tensor
* Add tests for ravel
kratsg authored Oct 27, 2020
1 parent 1e3f729 commit 638e5c7
Showing 5 changed files with 116 additions and 10 deletions.
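
The new tensorlib.ravel API is intended to behave the same way on every backend (a single shared test below exercises it). A minimal usage sketch, not part of this diff, assuming all four backends are installed:

import pyhf

# Hypothetical downstream check: ravel flattens a 2x3 tensor identically
# across backends.
for backend in ["numpy", "pytorch", "tensorflow", "jax"]:
    pyhf.set_backend(backend)
    tb = pyhf.tensorlib
    flat = tb.ravel(tb.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]))
    assert tb.tolist(flat) == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
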
20 changes: 20 additions & 0 deletions src/pyhf/tensor/jax_backend.py
@@ -304,6 +304,26 @@ def shape(self, tensor):
    def reshape(self, tensor, newshape):
        return jnp.reshape(tensor, newshape)

    def ravel(self, tensor):
        """
        Return a flattened view of the tensor, not a copy.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("jax")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> pyhf.tensorlib.ravel(tensor)
            DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float64)

        Args:
            tensor (Tensor): Tensor object

        Returns:
            `jax.interpreters.xla.DeviceArray`: A flattened array.
        """
        return jnp.ravel(tensor)

    def einsum(self, subscripts, *operands):
        """
        Evaluates the Einstein summation convention on the operands.
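
The dtype=float64 in the doctest output above assumes JAX is running in 64-bit mode; pyhf's jax backend enables this at import time, outside the hunk shown here. A minimal sketch of that assumption:

from jax.config import config

# Must run before any JAX arrays are created; otherwise results stay float32.
config.update("jax_enable_x64", True)
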
20 changes: 20 additions & 0 deletions src/pyhf/tensor/numpy_backend.py
@@ -289,6 +289,26 @@ def shape(self, tensor):
    def reshape(self, tensor, newshape):
        return np.reshape(tensor, newshape)

    def ravel(self, tensor):
        """
        Return a flattened view of the tensor, not a copy.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("numpy")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> pyhf.tensorlib.ravel(tensor)
            array([1., 2., 3., 4., 5., 6.])

        Args:
            tensor (Tensor): Tensor object

        Returns:
            `numpy.ndarray`: A flattened array.
        """
        return np.ravel(tensor)

    def einsum(self, subscripts, *operands):
        """
        Evaluates the Einstein summation convention on the operands.
20 changes: 20 additions & 0 deletions src/pyhf/tensor/pytorch_backend.py
@@ -189,6 +189,26 @@ def reshape(self, tensor, newshape):
    def shape(self, tensor):
        return tuple(map(int, tensor.shape))

    def ravel(self, tensor):
        """
        Return a flattened view of the tensor, not a copy.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("pytorch")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> pyhf.tensorlib.ravel(tensor)
            tensor([1., 2., 3., 4., 5., 6.])

        Args:
            tensor (Tensor): Tensor object

        Returns:
            `torch.Tensor`: A flattened array.
        """
        return tensor.view(-1)

    def sum(self, tensor_in, axis=None):
        return (
            torch.sum(tensor_in)
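
One design note on the implementation above: tensor.view(-1) only works on contiguous tensors, while torch.reshape falls back to a copy when a view is impossible. A short sketch of that caveat (the non-contiguous input here is hypothetical, not something this diff claims pyhf passes):

import torch

t = torch.arange(6.0).reshape(2, 3).t()  # transpose -> non-contiguous
try:
    flat = t.view(-1)  # raises RuntimeError for non-contiguous tensors
except RuntimeError:
    flat = t.reshape(-1)  # reshape copies when a view cannot be formed
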
21 changes: 21 additions & 0 deletions src/pyhf/tensor/tensorflow_backend.py
@@ -256,6 +256,27 @@ def shape(self, tensor):
    def reshape(self, tensor, newshape):
        return tf.reshape(tensor, newshape)

    def ravel(self, tensor):
        """
        Return a flattened view of the tensor, not a copy.

        Example:

            >>> import pyhf
            >>> pyhf.set_backend("tensorflow")
            >>> tensor = pyhf.tensorlib.astensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            >>> t_ravel = pyhf.tensorlib.ravel(tensor)
            >>> print(t_ravel)
            tf.Tensor([1. 2. 3. 4. 5. 6.], shape=(6,), dtype=float32)

        Args:
            tensor (Tensor): Tensor object

        Returns:
            `tf.Tensor`: A flattened array.
        """
        return self.reshape(tensor, -1)

    def divide(self, tensor_in_1, tensor_in_2):
        return tf.divide(tensor_in_1, tensor_in_2)

45 changes: 35 additions & 10 deletions tests/test_tensor.py
@@ -66,6 +66,41 @@ def test_simple_tensor_ops(backend):
    assert tb.tolist(tb.conditional((a > b), lambda: a + b, lambda: a - b)) == -1.0


def test_tensor_where_scalar(backend):
    tb = pyhf.tensorlib
    assert tb.tolist(tb.where(tb.astensor([1, 0, 1], dtype="bool"), 1, 2)) == [1, 2, 1]


def test_tensor_where_tensor(backend):
    tb = pyhf.tensorlib
    assert (
        tb.tolist(
            tb.where(
                tb.astensor([1, 0, 1], dtype="bool"),
                tb.astensor([1, 1, 1]),
                tb.astensor([2, 2, 2]),
            )
        )
        == [1, 2, 1]
    )


def test_tensor_ravel(backend):
    tb = pyhf.tensorlib
    assert (
        tb.tolist(
            tb.ravel(
                tb.astensor(
                    [
                        [1, 2, 3],
                        [4, 5, 6],
                    ]
                )
            )
        )
    ) == [1, 2, 3, 4, 5, 6]
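
A hypothetical companion test, not part of this commit, sketching the already-flat edge case (name and placement are assumptions):

def test_tensor_ravel_1d(backend):
    # ravel of a 1-D tensor should leave the values unchanged
    tb = pyhf.tensorlib
    assert tb.tolist(tb.ravel(tb.astensor([1, 2, 3]))) == [1, 2, 3]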


def test_complex_tensor_ops(backend):
    tb = pyhf.tensorlib
    assert tb.tolist(tb.outer(tb.astensor([1, 2, 3]), tb.astensor([4, 5, 6]))) == [
@@ -90,16 +125,6 @@ def test_complex_tensor_ops(backend):
        1,
        1,
    ]
    assert (
        tb.tolist(
            tb.where(
                tb.astensor([1, 0, 1], dtype="bool"),
                tb.astensor([1, 1, 1]),
                tb.astensor([2, 2, 2]),
            )
        )
        == [1, 2, 1]
    )


def test_ones(backend):
