Skip to content
This repository has been archived by the owner on Nov 7, 2024. It is now read-only.

Implemented Cholesky decomposition to numpy and tensorflow backends #890

Open
wants to merge 29 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
016d654
Added power to Jax backend and test
LuiGiovanni Dec 5, 2020
62a06c0
Added power function to the Jax backend and test.
LuiGiovanni Dec 5, 2020
43952dd
Fixed white space error.
LuiGiovanni Dec 5, 2020
b0f579e
Re-made the power test function.
LuiGiovanni Dec 5, 2020
e9f9869
Fixed typo
LuiGiovanni Dec 5, 2020
9e16a9f
Testing out numpy square.
LuiGiovanni Dec 5, 2020
81b4d8b
Fixed issues with the assertion in test, should work now.
LuiGiovanni Dec 5, 2020
195f691
Added NotImplementedError function and it's respective test for Chols…
LuiGiovanni Dec 6, 2020
737f071
Fixed line too long error
LuiGiovanni Dec 6, 2020
0fbe8fb
Removing changes for different branches.
LuiGiovanni Dec 6, 2020
8c37656
Revert "Fixed issues with the assertion in test, should work now."
LuiGiovanni Dec 6, 2020
a894960
Fixed a few pylint issues, should work now
LuiGiovanni Dec 6, 2020
3ba747f
Fixed typo in the error message
LuiGiovanni Dec 6, 2020
a112ebe
Merge branch 'master' into CholeskyDecomposition
alewis Dec 7, 2020
f9fb40d
Renamed cholesky function, to a shorter name
LuiGiovanni Dec 7, 2020
4b4ece6
Merge branch 'CholeskyDecomposition' of https://github.com/LuiGiovann…
LuiGiovanni Dec 7, 2020
0c3817d
Renamed cholesky function from cholesky_decomposition to cholesky
LuiGiovanni Dec 7, 2020
aedfe2a
Implemented Cholesky to numpy and tensorflow with their respective tests
LuiGiovanni Dec 23, 2020
13026b5
Fixed pylint issues
LuiGiovanni Dec 23, 2020
11eba6a
Implemented Cholesky decomposition to tensorflow, numpy & pytorch bac…
LuiGiovanni Dec 25, 2020
4da6eac
Fixed minor pylint issues
LuiGiovanni Dec 25, 2020
6548989
Merge branch 'master' into CholeskyDecomposition
mganahl Dec 28, 2020
66b5b94
Merge branch 'master' into CholeskyDecomposition
mganahl Dec 28, 2020
dab2060
requested changes made removed unnecessary arguments and better matrix
LuiGiovanni Dec 28, 2020
e6b9c08
Merge branch 'CholeskyDecomposition' of https://github.com/LuiGiovann…
LuiGiovanni Dec 28, 2020
367ce3d
Fixed commented testing functions
LuiGiovanni Dec 28, 2020
92c931a
Merge branch 'master' into CholeskyDecomposition
mganahl Jan 4, 2021
a51cb9a
Merge branch 'master' into CholeskyDecomposition
mganahl Jan 5, 2021
2b269e0
Merge branch 'master' into CholeskyDecomposition
mganahl Jan 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions tensornetwork/backends/numpy/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,28 @@ def rq(
r = np.reshape(r, list(left_dims) + [center_dim])
q = np.reshape(q, [center_dim] + list(right_dims))
return r, q


def cholesky(
tf: Any,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename tf to np (the user passes the numpy module here)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ya

tensor: Tensor,
pivot_axis: int,
non_negative_diagonal: bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove non_negative argument

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(it only makes sense for QR)

) -> Tuple[Tensor, Tensor]:
"""
Computes the Cholesky decomposition of a tensor

See tensornetwork.backends.tensorflow.decompositions for details.
"""
left_dims = tf.shape(tensor)[:pivot_axis]
right_dims = tf.shape(tensor)[pivot_axis:]
tensor = tf.reshape(tensor,
[tf.reduce_prod(left_dims),
tf.reduce_prod(right_dims)])
L = tf.linalg.cholesky(tensor)
if non_negative_diagonal:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unnecessary. The cholesky decomposition should already return a lower-triangular matrix with a real, positive diagonal.

phases = tf.math.sign(tf.linalg.diag_part(L))
L = phases[:, None] * L
center_dim = tf.shape(L)[1]
L = tf.reshape(L, tf.concat([left_dims, [center_dim]], axis=-1))
return L
7 changes: 7 additions & 0 deletions tensornetwork/backends/numpy/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def test_qr(self):
q, r = decompositions.qr(np, random_matrix, 1, non_negative_diagonal)
self.assertAllClose(q.dot(r), random_matrix)

def test_cholesky(self):
#Assured positive-definite hermitian matrixs
random_matrix = [[1.0, 0], [0, 1.0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the identity matrix as "random" matrix is not an ideal test case. You can generate a positive definite matrix e.g. by

A = np.random.rand(D,D)
A  = A @ A.T.conj()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw. also test this for all supported dtypes

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do np.random.rand(D, D).astype(dtype) to get your fav type. if it's imaginary you might need to initialize the real and imaginary parts separately.

put np.random_seed(10) (or any other fixed integer) before the rand so that the matrix is fixed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no dtype defined in any of the tests, should I define one myself?

for non_negative_diagonal in [True, False]:
L = decompositions.cholesky(tf, random_matrix, 1, non_negative_diagonal)
self.assertAllClose(np.linalg.cholesky(random_matrix), L)

def test_max_singular_values(self):
random_matrix = np.random.rand(10, 10)
unitary1, _, unitary2 = np.linalg.svd(random_matrix)
Expand Down
23 changes: 23 additions & 0 deletions tensornetwork/backends/pytorch/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,29 @@ def svd(
return u, s, vh, s_rest


def cholesky(
torch: Any,
tensor: Tensor,
pivot_axis: int,
non_negative_diagonal: bool = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove non_negative_diagonal (see above)

) -> Tuple[Tensor, Tensor]:
"""
Computes the Cholesky decomposition of a tensor

See tensornetwork.backends.tensorflow.decompositions for details.
"""
left_dims = np.shape(tensor)[:pivot_axis]
right_dims = np.shape(tensor)[pivot_axis:]

tensor = torch.reshape(tensor, (np.prod(left_dims), np.prod(right_dims)))
L = torch.cholesky(tensor)
if non_negative_diagonal:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove (see above)

phases = torch.sign(torch.diagonal(L))
L = phases[:, None] * L
center_dim = L.shape[1]
L = torch.reshape(L, list(left_dims) + [center_dim])
return L

def qr(
torch: Any,
tensor: Tensor,
Expand Down
7 changes: 7 additions & 0 deletions tensornetwork/backends/pytorch/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ def test_expected_shapes_rq():
assert r.shape == (2, 3, 6)
assert q.shape == (6, 4, 5)

def test_cholesky():
#Assured positive-definite hermitian matrix
random_matrix = np.array([[1.0, 0], [0, 1.0]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use better test matrix (see above)

random_matrix = torch.from_numpy(random_matrix)
for non_negative_diagonal in [True, False]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

L = decompositions.cholesky(torch, random_matrix, 1, non_negative_diagonal)
np.testing.assert_allclose(torch.cholesky(random_matrix), L)

def test_rq():
random_matrix = torch.rand([10, 10], dtype=torch.float64)
Expand Down
25 changes: 25 additions & 0 deletions tensornetwork/backends/tensorflow/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,28 @@ def rq(
r = tf.reshape(r, tf.concat([left_dims, [center_dim]], axis=-1))
q = tf.reshape(q, tf.concat([[center_dim], right_dims], axis=-1))
return r, q


def cholesky(
tf: Any,
tensor: Tensor,
pivot_axis: int,
non_negative_diagonal: bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

) -> Tuple[Tensor, Tensor]:
""" Computes de cholesky decomposition of a tensor.

Returns the Cholesky decomposition of a tensor which we treat as a
square matrix
"""
left_dims = tf.shape(tensor)[:pivot_axis]
right_dims = tf.shape(tensor)[pivot_axis:]
tensor = tf.reshape(tensor,
[tf.reduce_prod(left_dims),
tf.reduce_prod(right_dims)])
L = tf.linalg.cholesky(tensor)
if non_negative_diagonal:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above

phases = tf.math.sign(tf.linalg.diag_part(L))
L = phases[:, None] * L
center_dim = tf.shape(L)[1]
L = tf.reshape(L, tf.concat([left_dims, [center_dim]], axis=-1))
return L
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a newline here (at the end of the file)

7 changes: 7 additions & 0 deletions tensornetwork/backends/tensorflow/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ def test_qr(self):
q, r = decompositions.qr(tf, random_matrix, 1, non_negative_diagonal)
self.assertAllClose(tf.tensordot(q, r, ([1], [0])), random_matrix)

def test_cholesky(self):
#Assured positive-definite hermitian matrix
random_matrix = [[1.0, 0], [0, 1.0]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modify test, see above

for non_negative_diagonal in [True, False]:
L = decompositions.cholesky(tf, random_matrix, 1, non_negative_diagonal)
self.assertAllClose(np.linalg.cholesky(random_matrix), L)

def test_rq_defun(self):
random_matrix = np.random.rand(10, 10)
for non_negative_diagonal in [True, False]:
Expand Down